# Source code for orsopy.slddb.database

"""
Manage database creation, insertion and access.
"""

import sqlite3

from .comparators import Comparator
from .dbconfig import (DB_MATERIALS_CONVERTERS, DB_MATERIALS_FIELD_DEFAULTS, DB_MATERIALS_FIELDS, DB_MATERIALS_NAME,
                       db_lookup)
from .importers import importers
from .material import Formula, Material


class SLDDB:
    """
    Database to store material parameters to calculate scattering length
    densities (SLDs) for neutron and x-ray scattering.

    All data lives in one SQLite table named ``DB_MATERIALS_NAME`` whose
    columns are created from ``DB_MATERIALS_FIELDS``. Values are translated
    between Python and SQL through the converter objects found at
    ``db_lookup[field][1]``.
    """

    def __init__(self, dbfile):
        """Open a connection to the SQLite database *dbfile*."""
        self.db = sqlite3.connect(dbfile)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _check_fields(data):
        """Raise ``KeyError`` if a key of *data* is not a known table column."""
        for key in data:
            if key not in DB_MATERIALS_FIELDS:
                raise KeyError("%s is not a valid data field" % key)

    @staticmethod
    def _join_clauses(clauses, join_and):
        """
        Combine SQL condition fragments into one parenthesised expression.

        The parentheses matter: SQL's ``AND`` binds tighter than ``OR``, so an
        unwrapped ``x AND a OR b`` would be read as ``(x AND a) OR b``.
        Returns an empty string when *clauses* is empty.
        """
        if not clauses:
            return ""
        joiner = " AND " if join_and else " OR "
        return "(%s)" % joiner.join(clauses)

    @staticmethod
    def _where(condition, filter_invalid):
        """
        Build the ``WHERE`` part of a SELECT (including its leading space).

        :param condition: pre-built search expression, may be empty.
        :param filter_invalid: if true, restrict to rows with ``invalid IS NULL``.
        :return: ``" WHERE ..."`` or an empty string if nothing restricts the query.
        """
        parts = []
        if filter_invalid:
            parts.append("invalid IS NULL")
        if condition:
            parts.append(condition)
        if not parts:
            return ""
        return " WHERE " + " AND ".join(parts)

    @staticmethod
    def _format_column(field, converter, default):
        """Return the column definition ``"<name> <sql type> [DEFAULT <value>]"``."""
        default_sql = "DEFAULT %s" % default if default is not None else ""
        return "%s %s %s" % (field, converter.sql_type, default_sql)

    def _column_definitions(self, start=0):
        """Column definitions for all configured fields from index *start* on."""
        return [
            self._format_column(field, converter, default)
            for field, converter, default in zip(
                DB_MATERIALS_FIELDS[start:], DB_MATERIALS_CONVERTERS[start:], DB_MATERIALS_FIELD_DEFAULTS[start:]
            )
        ]

    def _build_filter(self, join_and, data):
        """
        Translate search keywords into an SQL condition and its bound values.

        A value that already is a ``Comparator`` is used directly (its ``key``
        is set to the field name); any other value is wrapped by the
        comparator of the field's converter. Comparators that produce an
        empty query string contribute no clause.

        :return: tuple ``(condition, args)``; *condition* is ``""`` if no
                 comparator produced a clause.
        """
        clauses = []
        args = []
        for key, value in data.items():
            if isinstance(value, Comparator):
                # user has supplied a comparator instead of a value
                comparator = value
                comparator.key = key
            else:
                # use comparator for specific validator
                comparator = db_lookup[key][1].comparator(value, key)
            clause = comparator.query_string()
            args += comparator.query_args()
            if clause:
                clauses.append(clause)
        return self._join_clauses(clauses, join_and), args

    def _execute_and_commit(self, statement, args=()):
        """Run one statement, always release the cursor, then commit."""
        c = self.db.cursor()
        try:
            c.execute(statement, args)
        finally:
            c.close()
        self.db.commit()

    # ------------------------------------------------------------------
    # public interface
    # ------------------------------------------------------------------
    def import_material(self, filename, name=None, commit=True):
        """
        Read a material from *filename* and add it to the database.

        The importer is selected by comparing the text after the last ``.``
        of *filename* with each importer's ``suffix``.

        :param filename: path of the file to import (a string).
        :param name: name for the entry; defaults to the name found by the importer.
        :param commit: passed on to :meth:`add_material`.
        :raises IOError: if no importer handles the file, including file
                         names that have no suffix at all.
        """
        _, dot, suffix = filename.rpartition(".")
        res = None
        if dot:
            for importer in importers:
                if importer.suffix == suffix:
                    res = importer(filename)
                    break
        if res is None:
            raise IOError("File import failed for %s, no suitable importer found" % filename)
        if name is None:
            name = res.name
        # reuse the parsed result instead of reading the file a second time
        return self.add_material(name, res.formula, commit=commit, **res)

    def add_material(self, name, formula, commit=True, **data):
        """
        Insert a new material.

        :param name: material name.
        :param formula: chemical formula.
        :param commit: commit the transaction after inserting.
        :param data: further column values; at least one of ``density``,
                     ``FU_volume``, ``SLD_n`` or the pair ``SLD_x``/``E_x``
                     has to be supplied.
        :raises KeyError: if a key of *data* is not a table column.
        :raises ValueError: if the density can't be determined from *data* or
                            an entry with identical values already exists.
        """
        din = {}
        for key, value in data.items():
            if key not in DB_MATERIALS_FIELDS:
                raise KeyError("%s is not a valid data field" % key)
            din[key] = db_lookup[key][1].convert(value)

        has_density_info = (
            "density" in din or "FU_volume" in din or "SLD_n" in din or ("SLD_x" in din and "E_x" in din)
        )
        if not has_density_info:
            raise ValueError("Not enough information to determine density")

        din["name"] = db_lookup["name"][1].convert(name)
        din["formula"] = db_lookup["formula"][1].convert(formula)

        # column names were validated against DB_MATERIALS_FIELDS above,
        # all values are passed as bound parameters
        columns = list(din)
        values = tuple(din.values())
        c = self.db.cursor()
        try:
            # check if entry already exists
            qstr = "SELECT * FROM %s WHERE %s" % (
                DB_MATERIALS_NAME,
                " AND ".join(["%s=?" % key for key in columns]),
            )
            c.execute(qstr, values)
            if c.fetchone() is not None:
                raise ValueError("Entry with this data already exists")

            qstr = "INSERT INTO %s (%s) VALUES (%s)" % (
                DB_MATERIALS_NAME,
                ", ".join(columns),
                ", ".join(["?"] * len(columns)),
            )
            c.execute(qstr, values)
        finally:
            c.close()
        if commit:
            self.db.commit()

    def update_material(self, ID, commit=True, **data):
        """
        Overwrite fields of the entry *ID* with the values in *data*.

        The current row is loaded through :meth:`search_material`, merged with
        *data* and written back. ``updated`` is set to the current timestamp
        and the validation columns are reset to NULL.

        :raises IndexError: if no entry with this ID exists.
        :raises KeyError: if a key is not a table column.
        :raises ValueError: if none of ``density``, ``FU_volume``, ``SLD_n``
                            or ``SLD_x`` has a value after the merge.
        """
        matches = self.search_material(ID=ID, filter_invalid=False)
        if not matches:
            raise IndexError("No material with ID %s found" % ID)
        merged = matches[0]
        merged.update(data)
        # these columns are managed by the UPDATE statement itself
        for key in ("ID", "updated", "validated", "validated_by"):
            merged.pop(key, None)

        din = {}
        for key, value in merged.items():
            if key not in DB_MATERIALS_FIELDS:
                raise KeyError("%s is not a valid data field" % key)
            # None is written as NULL without going through the converter
            din[key] = value if value is None else db_lookup[key][1].convert(value)

        # NOTE(review): unlike add_material this does not require E_x to
        # accompany SLD_x - confirm whether that difference is intended.
        if not any(din.get(field) is not None for field in ("density", "FU_volume", "SLD_n", "SLD_x")):
            raise ValueError("Not enough information to determine density")

        qstr = "UPDATE %s SET %s,updated = CURRENT_TIMESTAMP,validated = NULL, validated_by = NULL WHERE ID==?" % (
            DB_MATERIALS_NAME,
            ", ".join(["%s = ?" % key for key in din]),
        )
        c = self.db.cursor()
        try:
            c.execute(qstr, tuple(din.values()) + (ID,))
        finally:
            c.close()
        if commit:
            self.db.commit()

    def search_material(self, join_and=True, serializable=False, filter_invalid=True, limit=100, offset=0, **data):
        """
        Search entries and return them as a list of dictionaries.

        Results are ordered by ``validated``, ``selected`` and ``accessed``
        (all descending). When search terms are given, the ``accessed``
        counter of every row matching them is incremented.

        :param join_and: combine the search terms with AND (otherwise OR).
        :param serializable: convert values with the converters'
                             ``revert_serializable`` instead of ``revert``.
        :param filter_invalid: skip rows whose ``invalid`` column is set.
        :param limit: maximum number of rows returned.
        :param offset: number of leading rows to skip.
        :param data: field name -> value or ``Comparator`` to search for.
        :raises KeyError: if a key of *data* is not a table column.
        """
        self._check_fields(data)
        condition, args = self._build_filter(join_and, data)
        query = "SELECT * FROM %s%s ORDER BY validated DESC, selected DESC, accessed DESC LIMIT %i,%i" % (
            DB_MATERIALS_NAME,
            self._where(condition, filter_invalid),
            offset,
            limit,
        )

        c = self.db.cursor()
        try:
            c.execute(query, args)
            rows = c.fetchall()
            keys = [column[0] for column in c.description]
            if condition:
                # update access counter
                ustr = "UPDATE %s SET accessed = accessed + 1 WHERE %s" % (DB_MATERIALS_NAME, condition)
                c.execute(ustr, args)
        finally:
            c.close()
        self.db.commit()

        # convert values
        output = []
        for row in rows:
            rowdict = {}
            for key, value in zip(keys, row):
                converter = db_lookup[key][1]
                if serializable:
                    rowdict[key] = converter.revert_serializable(value)
                else:
                    rowdict[key] = converter.revert(value)
            output.append(rowdict)
        return output

    def count_material(self, join_and=True, filter_invalid=True, **data):
        """
        Return the number of entries matching the search terms.

        Takes the same search arguments as :meth:`search_material` but does
        not touch the access counter.

        :raises KeyError: if a key of *data* is not a table column.
        """
        self._check_fields(data)
        condition, args = self._build_filter(join_and, data)
        query = "SELECT COUNT(*) FROM %s%s" % (DB_MATERIALS_NAME, self._where(condition, filter_invalid))

        c = self.db.cursor()
        try:
            c.execute(query, args)
            result = c.fetchone()
        finally:
            c.close()
        self.db.commit()
        return result[0]

    def select_material(self, result) -> "Material":
        """
        Generate a ``Material`` from a database entry and increment the
        entry's ``selected`` counter.

        :param result: one row dictionary as returned by :meth:`search_material`.
        """
        formula = Formula(result["formula"])
        # the formula unit volume is only handed over when no density is stored
        if result["density"]:
            fu_volume = None
        else:
            fu_volume = result["FU_volume"]

        extra_data = {}
        if result["invalid"] is not None:
            extra_data["WARNING"] = (
                "This entry has been invalidated by ORSO on %s, "
                "please contact %s for more information." % (result["invalid"], result["invalid_by"])
            )
        extra_data["ID"] = int(result["ID"])
        extra_data["ORSO_validated"] = result["validated"] is not None
        extra_data["reference"] = result.get("reference", "")
        extra_data["doi"] = result.get("doi", "")
        extra_data["description"] = result.get("description", "")

        m = Material(
            formula,
            dens=result["density"],
            fu_volume=fu_volume,
            rho_n=result["SLD_n"],
            xsld=result["SLD_x"],
            xE=result["E_x"],
            mu=result["mu"],
            ID=result["ID"],
            name=result["name"],
            extra_data=extra_data,
        )

        ustr = "UPDATE %s SET selected = selected + 1 WHERE ID == ?" % DB_MATERIALS_NAME
        self._execute_and_commit(ustr, (result["ID"],))
        return m

    def validate_material(self, ID, user):
        """Mark entry *ID* as validated by *user* now and clear its invalid flag."""
        ustr = (
            "UPDATE %s SET validated = CURRENT_TIMESTAMP, validated_by = ?,"
            " invalid = NULL, invalid_by = NULL WHERE ID == ?" % DB_MATERIALS_NAME
        )
        self._execute_and_commit(ustr, (user, ID))

    def invalidate_material(self, ID, user):
        """Mark entry *ID* as invalidated by *user* now and clear its validation."""
        ustr = (
            "UPDATE %s SET invalid = CURRENT_TIMESTAMP, invalid_by = ?, "
            " validated = NULL, validated_by = NULL WHERE ID == ?" % DB_MATERIALS_NAME
        )
        self._execute_and_commit(ustr, (user, ID))

    def create_table(self):
        """Create the materials table with all configured columns."""
        qstr = "CREATE TABLE %s (%s)" % (DB_MATERIALS_NAME, ", ".join(self._column_definitions()))
        self._execute_and_commit(qstr)

    def create_database(self):
        """Set up a fresh database (currently just the materials table)."""
        self.create_table()
        self.db.commit()

    def update_fields(self):
        """
        Bring the table's columns in line with ``DB_MATERIALS_FIELDS``.

        Nothing is done when the columns already match. If the existing
        columns are a prefix of the configured ones, the missing columns are
        appended. Otherwise the table is rebuilt through a temporary table,
        copying every column that still exists in the configuration.
        """
        wanted = list(DB_MATERIALS_FIELDS)
        c = self.db.cursor()
        try:
            c.execute("SELECT * FROM %s LIMIT 1" % DB_MATERIALS_NAME)
            c.fetchall()
            fields = [col[0] for col in c.description]
            if fields == wanted:
                return

            if wanted[: len(fields)] != fields:
                # need to reorder and/or add/remove columns of the database, requires copy of table
                c.execute("CREATE TABLE tmp_table (%s)" % ", ".join(self._column_definitions()))
                kept = ",".join([field for field in fields if field in wanted])
                c.execute("INSERT INTO tmp_table (%s) SELECT %s FROM %s" % (kept, kept, DB_MATERIALS_NAME))
                c.execute("DROP TABLE %s" % DB_MATERIALS_NAME)
                c.execute("ALTER TABLE tmp_table RENAME TO %s" % DB_MATERIALS_NAME)
            else:
                # append new columns; SQLite accepts only a single column per
                # ALTER TABLE ... ADD statement, so issue one for each
                for definition in self._column_definitions(start=len(fields)):
                    c.execute("ALTER TABLE %s ADD %s" % (DB_MATERIALS_NAME, definition))
        finally:
            c.close()
        self.db.commit()

    def backup(self, filename):
        """Make a copy of the open database in the file *filename*."""
        out = sqlite3.connect(filename)
        try:
            with out:
                self.db.backup(out)
        finally:
            # release the target file even if the backup failed
            out.close()

    def __del__(self):
        # self.db is missing if sqlite3.connect() failed inside __init__
        db = getattr(self, "db", None)
        if db is not None:
            db.close()