diff --git a/docs/tutorials/creating_and_manipulating_qbits.ipynb b/docs/tutorials/creating_and_manipulating_qbits.ipynb index 58765cd..7a8d3dd 100644 --- a/docs/tutorials/creating_and_manipulating_qbits.ipynb +++ b/docs/tutorials/creating_and_manipulating_qbits.ipynb @@ -113,6 +113,18 @@ "qbits.draw(radius=\"nearest\")" ] }, + { + "cell_type": "markdown", + "source": "The `draw` method also supports an `interactive` mode for 3D lattices. Passing `interactive=True` renders a rotatable, zoomable Plotly figure that stays interactive in the documentation website.", + "metadata": {} + }, + { + "cell_type": "code", + "source": "qbits = qse.Qbits(cell=np.eye(3), positions=np.zeros((1, 3)))\nqbits = qbits.repeat((4, 4, 4))\nqbits.draw(radius=\"nearest\", interactive=True)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -275,4 +287,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index c9b16d5..491faaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ docs = [ "ipywidgets", "ipympl", "pulser", # required for creating the docs - "qiskit" + "qiskit", + "plotly", ] [dependency-groups] diff --git a/qse/qbits.py b/qse/qbits.py index fcccb45..094036c 100644 --- a/qse/qbits.py +++ b/qse/qbits.py @@ -11,7 +11,7 @@ from qse.cell import Cell from qse.operator import Operator, Operators from qse.qbit import Qbit -from qse.vis import draw_qbits +from qse.vis import draw_3d_qbits_interactive, draw_qbits class Qbits: @@ -540,6 +540,8 @@ def draw( colouring=None, units=None, equal_aspect=True, + alpha_min=0.0, + interactive=False, ): """ Visualize the positions of a set of qubits. @@ -563,19 +565,37 @@ def draw( equal_aspect : bool, optional Whether to have the same scaling for the axes. Defaults to True. + alpha_min : float, optional + Minimum alpha for bond opacity. Bond alphas are linearly rescaled + from (alpha_min, 1), where 1 is the shortest bond and alpha_min + is the longest. Defaults to 0.0. + interactive : bool, optional + If True, render an interactive 3D Plotly figure (3D lattices only). + Requires ``plotly`` to be installed. Defaults to False. See Also -------- qse.draw """ - draw_qbits( - self, - radius=radius, - show_labels=show_labels, - colouring=colouring, - units=units, - equal_aspect=equal_aspect, - ) + if interactive: + draw_3d_qbits_interactive( + self, + radius=radius, + show_labels=show_labels, + colouring=colouring, + units=units, + alpha_min=alpha_min, + ) + else: + return draw_qbits( + self, + radius=radius, + show_labels=show_labels, + colouring=colouring, + units=units, + equal_aspect=equal_aspect, + alpha_min=alpha_min, + ) def repeat(self, rep): """Create new repeated qbits object. diff --git a/qse/vis/__init__.py b/qse/vis/__init__.py index e4352b2..75abdea 100644 --- a/qse/vis/__init__.py +++ b/qse/vis/__init__.py @@ -15,12 +15,13 @@ "bar", "draw_amp_and_det", "draw_qbits", + "draw_3d_qbits_interactive", "draw_signal", "view_matrix", "qse_green", "qse_red", ] from qse.vis.colours import qse_green, qse_red -from qse.vis.qbits import draw_qbits +from qse.vis.qbits import draw_3d_qbits_interactive, draw_qbits from qse.vis.signal import draw_amp_and_det, draw_signal from qse.vis.visualise import bar, view_matrix diff --git a/qse/vis/qbits.py b/qse/vis/qbits.py index 6736025..d4466e9 100644 --- a/qse/vis/qbits.py +++ b/qse/vis/qbits.py @@ -8,7 +8,13 @@ def draw_qbits( - qbits, radius=None, show_labels=False, colouring=None, units=None, equal_aspect=True + qbits, + radius=None, + show_labels=False, + colouring=None, + units=None, + equal_aspect=True, + alpha_min=0.0, ): """ Visualize the positions of a set of qubits. @@ -34,6 +40,10 @@ def draw_qbits( equal_aspect : bool, optional Whether to have the same scaling for the axes. Defaults to True. + alpha_min : float, optional + Minimum alpha for bond opacity. Bond alphas are linearly rescaled + from (alpha_min, 1), where 1 is the shortest bond and alpha_min + is the longest. Defaults to 0.0. """ if colouring is not None: if len(colouring) != qbits.nqbits: @@ -52,14 +62,26 @@ def draw_qbits( elif min_dist > radius: draw_bonds = False - fig = plt.figure() + figsize = (10, 8) if qbits.dim == 3 else (6.4, 4.8) + fig = plt.figure(figsize=figsize) projection = "3d" if qbits.dim == 3 else None ax = fig.add_subplot(projection=projection) if equal_aspect: ax.set_aspect("equal") if qbits.dim == 3: - _draw_3d(qbits, draw_bonds, radius, rij, min_dist, ax) + _draw_3d( + qbits, + draw_bonds, + radius, + rij, + min_dist, + alpha_min, + units, + colouring, + show_labels, + ax, + ) else: _draw_2d( qbits, @@ -67,6 +89,7 @@ def draw_qbits( radius, rij, min_dist, + alpha_min, units, colouring, show_labels, @@ -75,9 +98,25 @@ def draw_qbits( return fig -def _draw_3d(qbits, draw_bonds, radius, rij, min_dist, ax): +def _draw_3d( + qbits, + draw_bonds, + radius, + rij, + min_dist, + alpha_min, + units, + colouring, + show_labels, + ax, +): positions = qbits.positions + ax.set_xlabel("x" + f" ({units})" if units is not None else "x") + ax.set_ylabel("y" + f" ({units})" if units is not None else "y") + ax.set_zlabel("z" + f" ({units})" if units is not None else "z") + ax.figure.subplots_adjust(right=0.85) + if draw_bonds: f_tol = 1.01 # fractional tolerance neighbours = rij <= radius * f_tol @@ -85,7 +124,7 @@ def _draw_3d(qbits, draw_bonds, radius, rij, min_dist, ax): ii, jj = np.where(neighbours) X, Y, Z = positions[ii].T U, V, W = (positions[jj] - positions[ii]).T - alpha = (min_dist / rij[neighbours]) ** 3 + alpha = alpha_min + (1 - alpha_min) * (min_dist / rij[neighbours]) ** 3 ax.quiver( X, @@ -101,8 +140,35 @@ def _draw_3d(qbits, draw_bonds, radius, rij, min_dist, ax): ) x, y, z = positions.T - for r, c in zip(rads, colors): - ax.scatter(x, y, z, s=r**2, color=(0.1, c, 0.5), zorder=1, alpha=0.8) + if colouring is not None: + inds0 = [j == 0 for j in colouring] + inds1 = [j == 1 for j in colouring] + for r, c in zip(rads, colors): + ax.scatter( + x[inds0], + y[inds0], + z[inds0], + s=r**2, + color=(0.1, c, 0.5), + zorder=1, + alpha=0.8, + ) + ax.scatter( + x[inds1], + y[inds1], + z[inds1], + s=r**2, + color=(c, 0.1, 0.5), + zorder=1, + alpha=0.8, + ) + else: + for r, c in zip(rads, colors): + ax.scatter(x, y, z, s=r**2, color=(0.1, c, 0.5), zorder=1, alpha=0.8) + + if show_labels: + for ind in range(qbits.nqbits): + ax.text(x[ind], y[ind], z[ind], s=qbits.labels[ind]) def _draw_2d( @@ -111,6 +177,7 @@ def _draw_2d( radius, rij, min_dist, + alpha_min, units, colouring, show_labels, @@ -136,7 +203,7 @@ def _draw_2d( ] ) for i, j in neighbours: - alpha = (min_dist / rij[i, j]) ** 3 + alpha = alpha_min + (1 - alpha_min) * (min_dist / rij[i, j]) ** 3 ax.plot([x[i], x[j]], [y[i], y[j]], c="gray", alpha=alpha, zorder=-1) if colouring is not None: @@ -158,3 +225,150 @@ def _draw_2d( if show_labels: for ind in range(qbits.nqbits): ax.text(x[ind], y[ind], s=qbits.labels[ind]) + + +def draw_3d_qbits_interactive( + qbits, radius=None, show_labels=False, colouring=None, units=None, alpha_min=0.0 +): + """ + Visualize the positions of a set of qubits as an interactive 3D Plotly figure. + + This function is intended for 3D lattices only. For 2D or 1D qubits use + :func:`draw_qbits` instead. + + Produces a rotatable, zoomable plot that renders as interactive HTML in + Jupyter Book without requiring a live kernel. + + Parameters + ---------- + qbits : qse.Qbits + The Qbits object. Must have ``qbits.dim == 3``. + radius : float | str, optional + A cutoff radius for visualizing bonds. + Pass 'nearest' to use the smallest distance between qubits. + If None, bonds are not drawn. + show_labels : bool, optional + Whether to show qubit labels. Defaults to False. + colouring : list, optional + A list of 0s and 1s assigning each qubit to a sublattice. + 0 → green, 1 → red. Must have the same length as the number of qubits. + units : str, optional + The units of distance, shown on the axis labels. + alpha_min : float, optional + Minimum opacity for bonds. Defaults to 0.0. + """ + import plotly.graph_objects as go + + if colouring is not None: + if len(colouring) != qbits.nqbits: + raise Exception("The length of colouring must equal the number of Qubits.") + colouring = [int(i) for i in colouring] + + positions = qbits.positions + x, y, z = positions.T + + def axis_label(name): + return name + f" ({units})" if units is not None else name + + fig = go.Figure( + layout=go.Layout( + width=800, + height=700, + scene=dict( + xaxis_title=axis_label("x"), + yaxis_title=axis_label("y"), + zaxis_title=axis_label("z"), + ), + ) + ) + + # --- bonds --- + draw_bonds = radius is not None + rij = None + if draw_bonds: + rij = qbits.get_all_distances() + min_dist = rij[np.logical_not(np.eye(qbits.nqbits, dtype=bool))].min() + if radius == "nearest": + radius = min_dist + elif min_dist > radius: + draw_bonds = False + + if draw_bonds: + f_tol = 1.01 + x_lines, y_lines, z_lines = [], [], [] + for i in range(qbits.nqbits - 1): + for j in range(i + 1, qbits.nqbits): + if rij[i, j] <= radius * f_tol: + x_lines += [positions[i, 0], positions[j, 0], None] + y_lines += [positions[i, 1], positions[j, 1], None] + z_lines += [positions[i, 2], positions[j, 2], None] + + fig.add_trace( + go.Scatter3d( + x=x_lines, + y=y_lines, + z=z_lines, + mode="lines", + line=dict(color="gray", width=2), + opacity=alpha_min + (1 - alpha_min) * 0.5, + showlegend=False, + hoverinfo="skip", + ) + ) + + # --- qubits --- + mode = "markers+text" if show_labels else "markers" + labels = list(qbits.labels) if show_labels else None + + def _rgb(r, g, b): + return "rgb({},{},{})".format(int(r * 255), int(g * 255), int(b * 255)) + + if colouring is not None: + inds0 = np.array([c == 0 for c in colouring]) + inds1 = np.array([c == 1 for c in colouring]) + for rad, c in zip(rads, colors): + fig.add_trace( + go.Scatter3d( + x=x[inds0], + y=y[inds0], + z=z[inds0], + mode=mode, + marker=dict(size=rad / 2, color=_rgb(0.1, c, 0.5), opacity=0.8), + text=np.array(labels)[inds0] if labels else None, + showlegend=False, + hoverinfo="skip", + ) + ) + fig.add_trace( + go.Scatter3d( + x=x[inds1], + y=y[inds1], + z=z[inds1], + mode=mode, + marker=dict(size=rad / 2, color=_rgb(c, 0.1, 0.5), opacity=0.8), + text=np.array(labels)[inds1] if labels else None, + showlegend=False, + hoverinfo="skip", + ) + ) + else: + for rad, c in zip(rads, colors): + fig.add_trace( + go.Scatter3d( + x=x, + y=y, + z=z, + mode=mode, + marker=dict(size=rad / 2, color=_rgb(0.1, c, 0.5), opacity=0.8), + text=labels, + showlegend=False, + hoverinfo="skip", + ) + ) + + try: + from IPython.display import HTML, display + + display(HTML(fig.to_html(full_html=False, include_plotlyjs="require"))) + except ImportError: + fig.show() diff --git a/uv.lock b/uv.lock index 8d23dc2..bd343b5 100644 --- a/uv.lock +++ b/uv.lock @@ -2237,6 +2237,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/de/21aa8394f16add8f7427f0a1326ccd2b3a2a8a3245c9252bc5ac034c6155/myst_parser-3.0.1-py3-none-any.whl", hash = "sha256:6457aaa33a5d474aca678b8ead9b3dc298e89c68e67012e73146ea6fd54babf1", size = 83163, upload-time = "2024-04-28T20:22:39.985Z" }, ] +[[package]] +name = "narwhals" +version = "2.22.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/1c/c80cb7719721a44846c6301ef118434bae30a423924bfad3a47f16bdc064/narwhals-2.22.0.tar.gz", hash = "sha256:6486282bb7e4b4ab55963efbd8be1451b764cc4874b74d1fd625eba9dc60b86f", size = 417565, upload-time = "2026-06-01T13:34:36.249Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/b6/e7cdde7b8e90d5dff25b622f95833ef26567ad184c977278b93a1cbd5717/narwhals-2.22.0-py3-none-any.whl", hash = "sha256:1421797ede01789cc1537619dbc3f36f840737240f748fdb24a60a0225fc80be", size = 453815, upload-time = "2026-06-01T13:34:34.127Z" }, +] + [[package]] name = "nbclient" version = "0.10.4" @@ -2698,6 +2707,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" }, ] +[[package]] +name = "plotly" +version = "6.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "narwhals" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/7f/0f100df1172aadf88a929a9dbb902656b0880ba4b960fe5224867159d8f4/plotly-6.7.0.tar.gz", hash = "sha256:45eea0ff27e2a23ccd62776f77eb43aa1ca03df4192b76036e380bb479b892c6", size = 6911286, upload-time = "2026-04-09T20:36:45.738Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/ad/cba91b3bcf04073e4d1655a5c1710ef3f457f56f7d1b79dcc3d72f4dd912/plotly-6.7.0-py3-none-any.whl", hash = "sha256:ac8aca1c25c663a59b5b9140a549264a5badde2e057d79b8c772ae2920e32ff0", size = 9898444, upload-time = "2026-04-09T20:36:39.812Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -4005,7 +4027,7 @@ wheels = [ [[package]] name = "qse" -version = "1.1.14" +version = "1.1.15" source = { editable = "." } dependencies = [ { name = "matplotlib" }, @@ -4031,6 +4053,7 @@ docs = [ { name = "jupyter-book" }, { name = "jupyter-sphinx" }, { name = "myst-nb" }, + { name = "plotly" }, { name = "pulser" }, { name = "qiskit" }, { name = "sphinx-autoapi" }, @@ -4078,6 +4101,7 @@ requires-dist = [ { name = "myqlm", marker = "extra == 'myqlm'" }, { name = "myst-nb", marker = "extra == 'docs'", specifier = ">=1.1.0" }, { name = "numpy" }, + { name = "plotly", marker = "extra == 'docs'" }, { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pulser", marker = "extra == 'dev'" }, { name = "pulser", marker = "extra == 'docs'" },