Source code for molecool.visualize

"""
Functions for visualization of molecules
"""
import numpy as np
import matplotlib.pyplot as plt

from .atom_data import atom_colors

[docs]def draw_molecule(coordinates, symbols, draw_bonds=None, save_location=None, dpi=300): """Draw a picture of a molecule. Parameters ---------- coordinates : np.ndarray The coordinates of the molecule. symbols : list The element of each atom in the molecule. draw_bonds : dict, (optional) Bonds to draw. Bonds should be indicated in a dictionary where the indices of bonded atoms are given as the keys of the dictionary. The default is None - no bonds are drawn. save_location : str, (optional) The location to save the image dpi : int, (optional) The resolution of the saved image Returns ------- ax : matplotlib axis The axis of the plot. """ # Create figure fig = plt.figure() ax = fig.add_subplot(111, projection='3d') # Get colors - based on atom name colors = [] for atom in symbols: colors.append(atom_colors[atom]) size = np.array(plt.rcParams['lines.markersize'] ** 2)*200/(len(coordinates)) ax.scatter(coordinates[:,0], coordinates[:,1], coordinates[:,2], marker="o", edgecolors='k', facecolors=colors, alpha=1, s=size) # Draw bonds if draw_bonds: for atoms, bond_length in draw_bonds.items(): atom1 = atoms[0] atom2 = atoms[1] ax.plot(coordinates[[atom1,atom2], 0], coordinates[[atom1,atom2], 1], coordinates[[atom1,atom2], 2], color='k') plt.axis('square') # Save figure if save_location: plt.savefig(save_location, dpi=dpi, graph_min=0, graph_max=2) return ax
[docs]def bond_histogram(bond_list, save_location=None, dpi=300, graph_min=0, graph_max=2): """Draw a histogram of bonds lengths in a molecule. Parameters --------- bond_list : dict Bonds to draw. Bonds should be indicated in a dictionary where the indices of bonded atoms are given as the keys of the dictionary. The default is None - no bonds are drawn. save_location : str, (optional) The location to save the image dpi : int, (optional) The resolution of the saved image """ lengths = [] for atoms, bond_length in bond_list.items(): lengths.append(bond_length) bins = np.linspace(graph_min, graph_max) fig = plt.figure() ax = fig.add_subplot(111) plt.xlabel('Bond Length (angstrom)') plt.ylabel('Number of Bonds') ax.hist(lengths, bins=bins) # Save figure if save_location: plt.savefig(save_location, dpi=dpi) return ax