#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# Copyright © 2014 uralbash <>
# Distributed under terms of the MIT license.

"""
SQLAlchemy helpers
"""
import json
import inspect
import itertools

import sqlalchemy
from sqlalchemy import or_, and_

def unjson(obj):
    """
    Decode `obj` from JSON when it is a JSON document.

    :Parameters:

        - `obj`: value to decode.

    :Returns:

        The result of ``json.loads(obj)``.  When ``json.loads`` raises
        ``TypeError`` (``obj`` is not a string-like value), `obj` is
        returned unchanged.  Any other error raised by ``json.loads``
        (for example on malformed JSON text) is not caught here.

    :Examples:

    >>> unjson('[1, 2]')
    [1, 2]
    >>> unjson(5)
    5
    """
    try:
        return json.loads(obj)
    except TypeError:
        # Not something json.loads accepts: hand the value back untouched.
        return obj
class TableProperty(object):
    """
    Read-only descriptor whose value is computed from the owner class's
    ``__table__`` attribute.

    The wrapped callable receives ``cls.__table__`` on every attribute
    access, both through the class and through an instance.

    :Parameters:

        - `func`: callable taking the table object and returning the value.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, inst, cls):
        # `inst` is ignored on purpose: the value depends on the class only.
        return self.func(cls.__table__)
class ClassProperty(TableProperty):
    """
    Read-only descriptor whose value is computed from the owner class.

    Same as :class:`TableProperty`, except that the wrapped callable
    receives the class itself instead of ``cls.__table__``.
    """

    def __get__(self, inst, cls):
        # `inst` is ignored on purpose: the value depends on the class only.
        return self.func(cls)
def get_attrname_by_colname(instance, name):
    """
    Get attribute name of SQLAlchemy instance by column name

    :Parameters:

        - `instance`: SQLAlchemy model instance.
        - `name`: Column name

    :Returns:

        Name of the mapped attribute whose column is called `name`, or
        ``None`` when the mapper has no such column.

    :Examples:

    >>> from sqlalchemy import Column, Integer
    >>> from sqlalchemy.ext.declarative import declarative_base
    >>> Base = declarative_base()
    >>> class MPTTPages(Base):
    ...     __tablename__ = "mptt_pages"
    ...     id = Column(Integer, primary_key=True)
    ...     left = Column("lft", Integer, nullable=False)
    >>> get_attrname_by_colname(MPTTPages(), 'lft')
    'left'
    """
    mapper = sqlalchemy.inspect(instance.__class__)
    for attr, column in mapper.c.items():
        if column.name == name:
            return attr
    return None
def get_relationship(table):
    """
    Return the relationships of a mapped class as a list.

    :Parameters:

        - `table`: SQLAlchemy model class (anything accepted by
          ``sqlalchemy.inspect`` that exposes ``relationships``).

    :Returns:

        ``None`` when `table` is falsy, otherwise a new list built from
        ``sqlalchemy.inspect(table).relationships``.
    """
    if not table:
        return None
    return list(sqlalchemy.inspect(table).relationships)
def get_empty_instance(table):
    """
    Return empty instance of model.

    Every named positional parameter of ``table.__init__`` (except the
    first one, i.e. ``self``) is passed as ``None``.

    :Parameters:

        - `table`: class to instantiate.

    :Returns:

        ``table(name=None, ...)`` for each named ``__init__`` parameter.
    """
    # ``inspect.getargspec`` was removed in Python 3.11; prefer
    # ``getfullargspec`` and fall back only where it does not exist.
    argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    init_params = argspec(table.__init__).args[1:]
    # list like ['name', 'group', 'visible'] to dict with empty
    # value as {'name': None, 'group': None, 'visible': None}
    return table(**dict.fromkeys(init_params))
def get_obj(session, table, pk):
    """
    Look up row(s) of `table` by primary key.

    :Parameters:

        - `session`: SQLAlchemy session used to build the query.
        - `table`: SQLAlchemy model class.
        - `pk`: primary key value in one of the supported shapes:

            * list/tuple of dicts - several composite keys, each dict
              keyed by primary key column name;
            * list/tuple of scalars - values of the first primary key
              column;
            * int or digit string - single-column primary key;
            * dict - one composite key keyed by primary key column name.

    :Returns:

        ``None`` when `pk` is falsy (this includes ``0``) or has an
        unsupported shape; the result of ``query.get(pk)`` for a scalar
        key; otherwise the filtered ``Query`` (not executed here).
    """
    if not pk:
        return None
    pk_list = get_pk(table)

    def composite_pk_clauses(values):
        # One equality test per primary key column, AND-ed together.
        # NOTE(review): the attribute is looked up by column name, so this
        # assumes the mapped attribute is named like its column - confirm.
        clauses = [getattr(table, item.name) == values[item.name]
                   for item in pk_list]
        return and_(*clauses)

    query = session.query(table)
    if isinstance(pk, (list, tuple)):
        if all(isinstance(item, dict) for item in pk):
            clauses = [composite_pk_clauses(item) for item in pk]
            return query.filter(or_(*clauses))
        return query.filter(pk_list[0].in_(pk))
    # Exact ``int`` check on purpose: ``bool`` must not be treated as a key.
    elif (type(pk) is int or str(pk).isdigit()) and len(pk_list) == 1:
        return query.get(pk)
    elif isinstance(pk, dict):
        return query.filter(composite_pk_clauses(pk))
    return None  # pragma: no cover
def get_pk(obj):
    """
    Return primary key columns by model class or instance.

    :Parameters:

        - `obj`: SQLAlchemy model instance or class.

    :Returns:

        ``sqlalchemy.inspect(obj).primary_key`` for a class,
        ``obj.__mapper__.primary_key`` for anything else.

    :Examples:

    >>> from sqlalchemy import Column, Integer
    >>> from sqlalchemy.ext.declarative import declarative_base
    >>> Base = declarative_base()
    >>> class User(Base):
    ...     __tablename__ = 'users'
    ...     id = Column(Integer, primary_key=True)
    >>> get_pk(User())
    (Column('id', Integer(), table=<users>, primary_key=True, nullable=False),)
    >>> get_pk(User)
    (Column('id', Integer(), table=<users>, primary_key=True, nullable=False),)
    """
    if inspect.isclass(obj):
        return sqlalchemy.inspect(obj).primary_key
    return obj.__mapper__.primary_key
def pk_to_list(obj, as_json=False):
    """
    Flatten the primary key of `obj` into ``[name, value, name, value, ...]``.

    :Parameters:

        - `obj`: SQLAlchemy model instance.
        - `as_json`: when true, return the list serialized with
          ``json.dumps`` instead of the list itself.

    :Returns:

        List alternating attribute name and attribute value for every
        primary key column, or its JSON string when `as_json` is true.
    """
    pk_list = []
    for item in get_pk(obj):
        # Columns may be mapped under a different attribute name.
        item_name = get_attrname_by_colname(obj, item.name)
        pk_list.extend((item_name, getattr(obj, item_name)))
    if as_json:
        return json.dumps(pk_list)
    return pk_list
def pk_to_dict(obj):
    """
    Return the primary key of `obj` as a ``{name: value}`` dict.

    :Parameters:

        - `obj`: SQLAlchemy model instance.

    :Returns:

        ``pk_list_to_dict(pk_to_list(obj))``; ``None`` when the flattened
        list is empty.
    """
    return pk_list_to_dict(pk_to_list(obj))
def pk_list_to_dict(pk_list):
    """
    convert list of multi pk to dict

    ['id', 1, 'id2', 22, 'foo', 'bar'] -> {'foo': 'bar', 'id2': 22, 'id': 1}

    :Parameters:

        - `pk_list`: flat sequence alternating key and value.

    :Returns:

        Dict pairing every even-indexed item with the item that follows
        it, or ``None`` when `pk_list` is empty or has an odd length.
    """
    if not pk_list or len(pk_list) % 2:
        return None
    names, values = pk_list[::2], pk_list[1::2]
    return dict(zip(names, values))
def get_obj_by_request_data(session, table, data):
    """
    Look up a `table` row using primary key values taken from `data`.

    :Parameters:

        - `session`: SQLAlchemy session passed on to ``get_obj``.
        - `table`: SQLAlchemy model class.
        - `data`: mapping that must contain a value for every primary
          key column name of `table`.

    :Returns:

        ``None`` when `data` is empty, when any primary key column name
        is missing from `data`, or when ``NoResultFound`` is raised;
        otherwise whatever ``get_obj`` returns for the collected key.
    """
    # Nothing to look up; skip the primary key inspection entirely.
    if not data:
        return None
    pk = {}
    for item in get_pk(table):
        if item.name not in data:
            return None
        pk[item.name] = data[item.name]
    try:
        return get_obj(session, table, pk)
    except sqlalchemy.orm.exc.NoResultFound:
        return None
def columns_by_group(table):
    """
    Return the columns of `table` arranged in ``(group_name, columns)`` pairs.

    :Parameters:

        - `table`: SQLAlchemy model class, optionally carrying a
          ``sacrud_detail_col`` attribute.

    :Returns:

        A new list with the items of ``table.sacrud_detail_col`` when that
        attribute exists.  Otherwise a single unnamed group
        ``[('', [(attr_key, column), ...])]`` covering every column
        attribute of the mapper, sorted by column creation order.
    """
    if hasattr(table, 'sacrud_detail_col'):
        return list(table.sacrud_detail_col)
    # Sorting on ``_creation_order`` keeps the columns in declaration order.
    # Reference fragment kept from the original comment (presumably a
    # sqlalchemy mailing-list thread - TODO confirm):
    # #!msg/sqlalchemy/k0MrlMBvAls/XPJqKzMYFQAJ
    column_attrs = sorted(
        sqlalchemy.inspection.inspect(table).column_attrs,
        key=lambda col: col.columns[0]._creation_order)
    return [
        ('', [(c.key, c._orig_columns[0]) for c in column_attrs])
    ]
def get_flat_columns(table):
    """
    Return every column listed in ``table.sacrud_detail_col`` as one flat list.

    :Parameters:

        - `table`: object optionally carrying ``sacrud_detail_col``, an
          iterable of items whose element ``[1]`` is an iterable of columns.

    :Returns:

        Flat list of the columns of all groups, in group order; an empty
        list when `table` has no ``sacrud_detail_col`` attribute.
    """
    if not hasattr(table, 'sacrud_detail_col'):
        return []
    return [col for item in table.sacrud_detail_col for col in item[1]]
def get_columns(obj):
    """
    Return the column collection of a model class or instance.

    :Parameters:

        - `obj`: object with a ``__table__`` attribute, or anything
          ``sqlalchemy.inspect`` can resolve to an object exposing
          ``mapper``.

    :Returns:

        ``obj.__table__.columns`` when `obj` has ``__table__``, otherwise
        ``sqlalchemy.inspect(obj).mapper.columns``.
    """
    if hasattr(obj, '__table__'):
        return obj.__table__.columns
    return sqlalchemy.inspect(obj).mapper.columns
def column_to_attr_name(key, table):
    """
    Translate a column name into the name of the attribute mapped to it.

    :Parameters:

        - `key`: column name to look for.
        - `table`: SQLAlchemy model class.

    :Returns:

        ``key`` of the first column attribute (in column creation order)
        whose original column is named `key`; `key` itself when no
        column attribute matches.
    """
    column_attrs = sorted(
        sqlalchemy.inspection.inspect(table).column_attrs,
        key=lambda col: col.columns[0]._creation_order)
    for c in column_attrs:
        if c._orig_columns[0].name == key:
            return c.key
    return key
