PyJeeves/pyjeeves/models/abc.py

185 lines
6 KiB
Python

"""
Define the abstract declarative base class shared by the raw Jeeves models
"""
from decimal import Decimal
from datetime import datetime
from sqlalchemy.sql.expression import and_
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.exc import OperationalError
from sqlalchemy.schema import MetaData, Column
from sqlalchemy.types import Integer
from sqlalchemy.orm.collections import InstrumentedList
from sqlservice import ModelBase, as_declarative
from pyjeeves import logging
from . import db
logger = logging.getLogger("PyJeeves." + __name__)

# Only these Jeeves tables are reflected into ``meta`` at import time.
_REFLECTED_TABLES = [
    'ar', 'ars', 'xae', 'xare', 'fr', 'kus', 'x1k',
    'oh', 'orp', 'lp', 'vg', 'xp', 'xm', 'prh', 'prl',
    'kp', 'kpw', 'cr', 'X4', 'xw',
]

logger.info("Reading Jeeves DB structure")
meta = MetaData()
try:
    # NOTE(review): the connection returned by db.raw.connection() is never
    # explicitly closed here -- confirm whether it is pooled/managed elsewhere.
    meta.reflect(bind=db.raw.connection(), only=_REFLECTED_TABLES)
except OperationalError:
    logger.error("Failed to read Jeeves DB structure")
    # Bare ``raise`` re-raises the original exception with its traceback intact.
    raise
@as_declarative(metadata=meta)
class RawBaseModel(ModelBase):
    """Declarative base for the raw Jeeves models.

    Generalizes ``__init__``, ``to_dict`` and simple persistence helpers
    based on the model's columns. It also declares the ``ForetagKod``
    primary-key column, which :meth:`_base_filters` restricts to 1.

    Class attributes subclasses may override:

    * ``__column_map__`` -- maps model attribute names to external key names.
      Incoming data that uses the external names is translated back by
      :meth:`_map_keys` and :meth:`from_dict`.
    * ``__to_dict_only__`` -- if non-empty, :meth:`to_dict` returns only these
      keys, renamed through ``__column_map__``.
    * ``__to_dict_filter__`` -- keys removed from :meth:`to_dict` output when
      ``__to_dict_only__`` is empty.
    """

    __to_dict_filter__ = []
    __to_dict_only__ = ()
    __column_map__ = {}

    # Tables are already reflected into ``meta``; extend_existing lets model
    # definitions apply to those existing Table objects.
    __table_args__ = {
        'extend_existing': True
    }

    # Value adapters for sqlservice's dict serialization (presumably selected
    # by value type): datetimes -> 'YYYY-MM-DD HH:MM' strings, Decimals -> floats.
    __dict_args__ = {
        'adapters': {
            datetime: lambda value, col, *_: value.strftime('%Y-%m-%d %H:%M'),
            Decimal: lambda value, col, *_: float(value)
        }
    }

    ForetagKod = Column(Integer, primary_key=True)

    def __reversed_column_map__(self):
        """Return ``__column_map__`` inverted (external name -> attribute name)."""
        return {v: k for k, v in self.__column_map__.items()}

    def __init__(self, data=None, **kargs):
        """Create a model from ``data`` and/or keyword arguments.

        A non-empty ``data`` is first translated by :meth:`_map_keys`.
        External names are renamed to attribute names, relationship values
        are converted, and unknown keys are dropped. The result is then
        passed to ``update`` together with ``kargs``.
        """
        if data:
            data = self._map_keys(data)
        self.update(data, **kargs)

    @classmethod
    def _base_filters(cls, obj, filters=None):
        """Return the base filter expression for queries on ``obj``.

        Subclasses may override this to add more filtering.

        Args:
            obj: Entity exposing a ``ForetagKod`` column (presumably a model
                class -- confirm against callers).
            filters: Optional extra SQL expression ANDed with the base
                condition. ``None`` (the default) adds nothing.

        Returns:
            An ``and_`` clause requiring ``obj.ForetagKod == 1`` together with
            ``filters`` when given.
        """
        # Build the clause per call; a module-level ``and_()`` default was
        # created once at import and empty ``and_()`` is deprecated.
        conditions = [obj.ForetagKod == 1]
        if filters is not None:
            conditions.append(filters)
        return and_(*conditions)

    def _map_columns(self, key):
        """Return the external name for attribute ``key`` (``key`` if unmapped)."""
        return self.__column_map__.get(key, key)

    def _map_keys(self, data=None):
        """Translate incoming ``data`` into values keyed by attribute name.

        Args:
            data: Mapping of input keys to values; ``None`` is treated as empty.

        Returns:
            dict: Keys found in the reversed ``__column_map__`` are renamed to
            their attribute names. Keys naming an existing attribute are kept,
            with relationship values converted by
            :meth:`_map_relationship_keys`. All other keys are dropped.
        """
        if data is None:
            data = {}
        rv = {}
        for external_key, attr_name in self.__reversed_column_map__().items():
            if external_key in data:
                rv[attr_name] = data[external_key]
        relationship_keys = self.relationships().keys()
        for key, value in data.items():
            if not hasattr(self, key):
                continue
            if key in relationship_keys:
                rv[key] = self._map_relationship_keys(key, value)
            else:
                rv[key] = value
        return rv

    def _map_relationship_keys(self, field, value):
        """Convert ``value`` into the form relationship ``field`` expects.

        Almost a copy of the equivalent sqlservice ``ModelBase`` logic.

        * List relationships: a non-list/tuple value is wrapped in a list.
          Each element that is not already an instance of the related class
          is passed to that class's constructor.
        * Scalar relationships: a non-empty dict is converted to an instance
          of the related class, and an empty dict becomes ``None``.
        * Anything else is returned unchanged.
        """
        relation_attr = getattr(self.__class__, field)
        uselist = relation_attr.property.uselist
        relation_class = relation_attr.property.mapper.class_

        if uselist:
            if not isinstance(value, (list, tuple)):  # pragma: no cover
                value = [value]
            value = [val if isinstance(val, relation_class)
                     else relation_class(val)
                     for val in value]
        elif value and isinstance(value, dict):
            value = relation_class(value)
        elif not value and isinstance(value, dict):
            # An empty dict for a scalar relationship nullifies it.
            value = None
        return value

    def descriptors_to_dict(self):
        """Return a ``dict`` mapping this model's descriptors to their data.

        Loaded values are taken from :attr:`__dict__`, so unloaded ORM
        attributes are not queried. The ``dict`` also includes the value of
        every ``hybrid_property`` descriptor.

        Note:
            Hybrid properties are evaluated with ``getattr`` and may therefore
            touch attributes that have not been loaded yet. Relationship
            data in the result is returned as model instances; use
            :meth:`to_dict` for non-ORM objects.

        Returns:
            dict
        """
        descriptors = self.descriptors()
        hybrids = {key: getattr(self, key) for key in descriptors.keys()
                   if isinstance(descriptors.get(key), hybrid_property)}
        loaded = {key: value for key, value in self.__dict__.items()
                  if key in descriptors}
        return {**hybrids, **loaded}

    def to_dict(self):
        """Serialize the model via sqlservice's ``to_dict`` with local filtering.

        Returns:
            dict: If ``__to_dict_only__`` is non-empty, only those keys,
            renamed through ``__column_map__``. Otherwise the full dict minus
            the keys listed in ``__to_dict_filter__``; listed keys that are
            not present are ignored.
        """
        rv = super().to_dict()
        if self.__to_dict_only__:
            return {
                self._map_columns(key): value
                for key, value in rv.items()
                if key in self.__to_dict_only__
            }
        for key in self.__to_dict_filter__:
            rv.pop(key, None)
        return rv

    def from_dict(self, data=None):
        """Update this instance in place from ``data`` and return it.

        Keys matching external names in ``__column_map__`` are written to the
        mapped attribute first. Keys naming an existing attribute are then
        assigned directly. The exception is attributes currently holding an
        ``InstrumentedList`` (list relationships), which are left untouched.
        ``None`` is treated as an empty mapping.
        """
        if data is None:
            data = {}
        for external_key, attr_name in self.__reversed_column_map__().items():
            if external_key in data:
                self[attr_name] = data[external_key]
        for key, value in data.items():
            if hasattr(self, key) and not isinstance(self[key], InstrumentedList):
                self[key] = value
        return self

    @staticmethod
    def _commit_session():
        """Commit ``db.raw_session``; on failure roll it back and re-raise."""
        try:
            db.raw_session.commit()
        except Exception:
            # A failed commit leaves the session unusable until rolled back.
            db.raw_session.rollback()
            raise

    def merge(self):
        """Merge this instance's state into ``db.raw_session``.

        Returns:
            The session-attached instance returned by ``Session.merge``.
            SQLAlchemy does not attach ``self`` itself, so callers should use
            the returned object.
        """
        return db.raw_session.merge(self)

    def commit(self):
        """Commit ``db.raw_session``, rolling back and re-raising on failure."""
        self._commit_session()

    def save(self):
        """Add this instance to ``db.raw_session``, commit, and return it."""
        db.raw_session.add(self)
        self._commit_session()
        return self

    def delete(self):
        """Delete this instance through ``db.raw_session`` and commit."""
        db.raw_session.delete(self)
        self._commit_session()