1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
// Copyright Global Phasing Ltd.
#include "common.h"
#include <nanobind/stl/bind_vector.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include "array.h"
#include "gemmi/read_cif.hpp" // for read_cif_gz
#include "gemmi/smcif.hpp" // for make_small_structure_from_block
#include "gemmi/small.hpp" // for SmallStructure
#include "gemmi/gz.hpp" // for estimate_uncompressed_size
#include "gemmi/interop.hpp" // for atom_to_site, mx_to_sx_structure
#include "gemmi/flat.hpp" // for FlatStructure, FlatAtom
//
using namespace gemmi;
NB_MAKE_OPAQUE(std::vector<SmallStructure::Site>)
NB_MAKE_OPAQUE(std::vector<FlatAtom>)
void add_small(nb::module_& m) {
// from smcif.hpp
m.def("read_small_structure", [](const std::string& path) {
cif::Block block = read_cif_gz(path).sole_block();
return new SmallStructure(make_small_structure_from_block(block));
}, nb::arg("path"), "Reads a small molecule CIF file.");
m.def("make_small_structure_from_block", &make_small_structure_from_block,
nb::arg("block"), "Takes CIF block and returns SmallStructure.");
// and an unrelated function from gz.hpp
m.def("estimate_uncompressed_size", &estimate_uncompressed_size,
nb::arg("path"),
"Returns uncompressed size of a .gz file (not always reliable)");
using gemmi::SmallStructure;
nb::class_<SmallStructure> small_structure(m, "SmallStructure");
nb::class_<SmallStructure::Site>(small_structure, "Site")
.def(nb::init<>())
.def("__init__", [](SmallStructure::Site* site, const Atom& atom, const UnitCell& cell) {
new(site) SmallStructure::Site(gemmi::atom_to_site(atom, cell));
})
.def_rw("label", &SmallStructure::Site::label)
.def_rw("type_symbol", &SmallStructure::Site::type_symbol)
.def_rw("fract", &SmallStructure::Site::fract)
.def_rw("occ", &SmallStructure::Site::occ)
.def_rw("u_iso", &SmallStructure::Site::u_iso)
.def_rw("element", &SmallStructure::Site::element)
.def_rw("charge", &SmallStructure::Site::charge)
.def_rw("disorder_group", &SmallStructure::Site::disorder_group)
.def_rw("aniso", &SmallStructure::Site::aniso)
.def("orth", &SmallStructure::Site::orth)
.def("clone", [](const SmallStructure::Site& self) { return new SmallStructure::Site(self); })
.def("__repr__", [](const SmallStructure::Site& self) {
return "<gemmi.SmallStructure.Site " + self.label + ">";
});
nb::bind_vector<std::vector<SmallStructure::Site>, rv_ri>(small_structure, "SiteList");
using AtomType = SmallStructure::AtomType;
nb::class_<AtomType>(small_structure, "AtomType")
.def_ro("symbol", &AtomType::symbol)
.def_ro("element", &AtomType::element)
.def_rw("dispersion_real", &AtomType::dispersion_real)
.def_rw("dispersion_imag", &AtomType::dispersion_imag)
.def("__repr__", [](const AtomType& self) {
return "<gemmi.SmallStructure.AtomType " + self.symbol + ">";
});
small_structure
.def(nb::init<>())
.def_rw("name", &SmallStructure::name)
.def_rw("cell", &SmallStructure::cell)
.def_ro("spacegroup", &SmallStructure::spacegroup,
nb::rv_policy::reference_internal)
.def_rw("spacegroup_hm", &SmallStructure::spacegroup_hm)
.def_rw("spacegroup_hall", &SmallStructure::spacegroup_hall)
.def_rw("spacegroup_number", &SmallStructure::spacegroup_number)
.def_rw("symops", &SmallStructure::symops)
.def_rw("sites", &SmallStructure::sites)
.def_ro("atom_types", &SmallStructure::atom_types)
.def_rw("wavelength", &SmallStructure::wavelength)
.def("add_site", [](SmallStructure& self, const SmallStructure::Site& site) {
self.sites.push_back(site);
})
.def("determine_and_set_spacegroup", &SmallStructure::determine_and_set_spacegroup,
nb::arg("order"))
.def("check_spacegroup", &SmallStructure::check_spacegroup)
.def("get_atom_type", &SmallStructure::get_atom_type)
.def("get_all_unit_cell_sites", &SmallStructure::get_all_unit_cell_sites)
.def("remove_hydrogens", &SmallStructure::remove_hydrogens)
.def("change_occupancies_to_crystallographic",
&SmallStructure::change_occupancies_to_crystallographic,
nb::arg("max_dist")=0.4)
.def("make_cif_block", &make_cif_block_from_small_structure)
.def("__repr__", [](const SmallStructure& self) {
return "<gemmi.SmallStructure: " + std::string(self.name) + ">";
});
m.def("mx_to_sx_structure", &gemmi::mx_to_sx_structure,
nb::arg("st"), nb::arg("n")=0);
// FlatStructure bindings
nb::class_<FlatStructure>(m, "FlatStructure")
.def(nb::init<const Structure&>(), nb::arg("structure"),
"Create a flat representation of a Structure")
.def("generate_structure", &FlatStructure::generate_structure,
"Reconstructs a Structure from the flat table of atoms")
.def("__len__", [](const FlatStructure& self) { return self.table.size(); })
.def("__repr__", [](const FlatStructure& self) {
return "<gemmi.FlatStructure with " + std::to_string(self.table.size()) + " atoms>";
})
// NumPy-like array properties for atomic data
.def_prop_ro("b_iso", [](FlatStructure& self) {
return vector_member_array(self.table, &FlatAtom::b_iso);
}, nb::rv_policy::reference_internal, "B-factors as numpy array")
.def_prop_ro("occ", [](FlatStructure& self) {
return vector_member_array(self.table, &FlatAtom::occ);
}, nb::rv_policy::reference_internal, "Occupancies as numpy array")
.def_prop_ro("pos", [](FlatStructure& self) {
// Create a view of positions as (N, 3) array
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom) / sizeof(double));
return nb::ndarray<nb::numpy, double, nb::shape<-1, 3>>(
&(self.table.data()->pos.x),
{self.table.size(), 3},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Positions as (N, 3) numpy array")
.def_prop_ro("charge", [](FlatStructure& self) {
return vector_member_array(self.table, &FlatAtom::charge);
}, nb::rv_policy::reference_internal, "Charges as numpy array")
.def_prop_ro("model_num", [](FlatStructure& self) {
return vector_member_array(self.table, &FlatAtom::model_num);
}, nb::rv_policy::reference_internal, "Model numbers as numpy array")
.def_prop_ro("selected", [](FlatStructure& self) {
return vector_member_array(self.table, &FlatAtom::selected);
}, nb::rv_policy::reference_internal, "Selection flags as numpy array")
// String fields as S8 (8-byte fixed-width string) numpy arrays
.def_prop_ro("atom_names", [](FlatStructure& self) {
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom));
return nb::ndarray<nb::numpy, char, nb::shape<-1, 8>>(
self.table.data()->atom_name,
{self.table.size(), 8},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Atom names as (N, 8) char array")
.def_prop_ro("residue_names", [](FlatStructure& self) {
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom));
return nb::ndarray<nb::numpy, char, nb::shape<-1, 8>>(
self.table.data()->residue_name,
{self.table.size(), 8},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Residue names as (N, 8) char array")
.def_prop_ro("chain_ids", [](FlatStructure& self) {
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom));
return nb::ndarray<nb::numpy, char, nb::shape<-1, 8>>(
self.table.data()->chain_id,
{self.table.size(), 8},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Chain IDs as (N, 8) char array")
.def_prop_ro("subchains", [](FlatStructure& self) {
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom));
return nb::ndarray<nb::numpy, char, nb::shape<-1, 8>>(
self.table.data()->subchain,
{self.table.size(), 8},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Subchain IDs as (N, 8) char array")
.def_prop_ro("entity_ids", [](FlatStructure& self) {
constexpr int64_t stride = static_cast<int64_t>(sizeof(FlatAtom));
return nb::ndarray<nb::numpy, char, nb::shape<-1, 8>>(
self.table.data()->entity_id,
{self.table.size(), 8},
nb::handle(),
{stride, 1});
}, nb::rv_policy::reference_internal, "Entity IDs as (N, 8) char array");
}
|