# Source code for kiwi.db.sqlalch

##
## Copyright (C) 2007 Async Open Source
##
## This program is free software; you can redistribute it and/or
## modify it under the terms of the GNU Lesser General Public License
## as published by the Free Software Foundation; either version 2
## of the License, or (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU Lesser General Public License for more details.
##
## You should have received a copy of the GNU Lesser General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., or visit: http://www.gnu.org/.
##
##
## Author(s):    Ali Afshar              <aafshar@gmail.com>
##               Johan Dahlin            <jdahlin@async.com.br>
##

"""
SQLAlchemy integration for Kiwi
"""

from sqlalchemy import and_, func, not_, or_

from kiwi.db.query import NumberQueryState, StringQueryState, \
     DateQueryState, DateIntervalQueryState, QueryExecuter, \
     NumberIntervalQueryState
from kiwi.interfaces import ISearchFilter


class SQLAlchemyQueryExecuter(QueryExecuter):
    """
    Query executer that turns search states into SQLAlchemy clauses.

    Each state is translated into a clause, either through the filter
    columns registered on the base class or through per-filter callbacks.
    Generic query callbacks may contribute further clauses. All clauses are
    combined with ``and_`` and handed to the query callable, which by
    default selects from the configured table through the session.
    """

    def __init__(self, session):
        """
        :param session: the SQLAlchemy session used by the default query
        """
        QueryExecuter.__init__(self)
        self.session = session
        self.table = None
        self._query_callbacks = []
        self._filter_query_callbacks = {}
        self._query = self._default_query
        self._full_text_indexes = {}

    #
    # Public API
    #

    def set_table(self, table):
        """
        Sets the SQLAlchemy table (or mapped object) for this executer

        :param table: the object passed to ``session.query()`` by the
          default query; its ``c`` attribute must expose the filter columns
        """
        self.table = table

    def add_query_callback(self, callback):
        """
        Adds a generic query callback

        The callback is called once per search with the full sequence of
        states. It should return a clause, or None to add nothing.

        :param callback: a callable
        :raises TypeError: if callback is not callable
        """
        if not callable(callback):
            raise TypeError("callback must be callable, not %r" % (callback,))
        self._query_callbacks.append(callback)

    def add_filter_query_callback(self, search_filter, callback):
        """
        Adds a query callback for the filter search_filter

        The callback is called with each state belonging to search_filter.
        It should return a clause, or None to add nothing.

        :param search_filter: a search filter
        :param callback: a callable
        :raises TypeError: if search_filter does not provide ISearchFilter
          or callback is not callable
        """
        if not ISearchFilter.providedBy(search_filter):
            raise TypeError("search_filter must provide ISearchFilter, "
                            "not %r" % (search_filter,))
        if not callable(callback):
            raise TypeError("callback must be callable, not %r" % (callback,))
        callbacks = self._filter_query_callbacks.setdefault(search_filter, [])
        callbacks.append(callback)

    def set_query(self, callback):
        """
        Overrides the default query mechanism.

        :param callback: a callable which will take one argument, the
          combined where clause (None when no clause applies), and return
          the search result. Passing None restores the default query.
        :raises TypeError: if callback is neither None nor callable
        """
        if callback is None:
            callback = self._default_query
        elif not callable(callback):
            raise TypeError("callback must be callable, not %r" % (callback,))
        self._query = callback

    #
    # QueryBuilder
    #

    def search(self, states):
        """
        Execute a search.

        :param states: the query states to search for; each one must have
          a ``filter`` attribute
        :returns: whatever the query callable returns for the combined clause
        :raises ValueError: if no table was set, or if a state's filter has
          neither columns nor callbacks while the default query is in use
          and no generic query callback was added
        """
        if self.table is None:
            raise ValueError("table cannot be None")
        table = self.table
        queries = []
        for state in states:
            search_filter = state.filter
            assert search_filter

            if search_filter in self._columns:
                # Column query
                query = self._construct_state_query(
                    table, state, self._columns[search_filter])
                # SQLAlchemy clauses have no meaningful truth value
                # (bool(col == value) compares hashes, most other clauses
                # raise), so only None means "no clause".
                if query is not None:
                    queries.append(query)
            elif search_filter in self._filter_query_callbacks:
                # Custom per filter/state query.
                for callback in self._filter_query_callbacks[search_filter]:
                    query = callback(state)
                    if query is not None:
                        queries.append(query)
            elif (self._query == self._default_query and
                  not self._query_callbacks):
                raise ValueError(
                    "You need to add a search column or a query callback "
                    "for filter %s" % (search_filter))

        # Generic callbacks see all the states at once.
        for callback in self._query_callbacks:
            query = callback(states)
            if query is not None:
                queries.append(query)

        if queries:
            query = and_(*queries)
        else:
            query = None

        return self._query(query)

    #
    # Private
    #

    def _default_query(self, query):
        """Runs query (a clause or None) against the configured table."""
        # NOTE(review): Query.select() is the pre-0.5 SQLAlchemy API; newer
        # releases use Query.filter() -- confirm the targeted version.
        return self.session.query(self.table).select(query)

    def _construct_state_query(self, table, state, columns):
        """
        Builds a clause matching state against any of the given columns.

        :param table: object whose ``c`` attribute exposes the columns
        :param state: the query state to translate
        :param columns: names of the columns to search
        :returns: an ``or_`` of the per-column clauses, or None if no
          column produced a clause
        :raises NotImplementedError: for an unsupported state type
        """
        parse = self._get_state_parser(state)
        queries = []
        for column in columns:
            query = parse(state, getattr(table.c, column))
            if query is not None:
                queries.append(query)
        if queries:
            return or_(*queries)
        return None

    def _get_state_parser(self, state):
        """Returns the _parse_* method handling the type of state."""
        # Order matters only if the state classes subclass each other; it
        # mirrors the original isinstance chain.
        if isinstance(state, NumberQueryState):
            return self._parse_number_state
        elif isinstance(state, NumberIntervalQueryState):
            return self._parse_number_interval_state
        elif isinstance(state, StringQueryState):
            return self._parse_string_state
        elif isinstance(state, DateQueryState):
            return self._parse_date_state
        elif isinstance(state, DateIntervalQueryState):
            return self._parse_date_interval_state
        raise NotImplementedError(state.__class__.__name__)

    def _parse_interval(self, start, end, table_field):
        """
        Returns a clause bounding table_field by start/end (inclusive).

        A None bound is open. Other falsy bounds such as 0 are real limits.
        Returns None when both bounds are None.
        """
        queries = []
        if start is not None:
            queries.append(table_field >= start)
        if end is not None:
            queries.append(table_field <= end)
        if queries:
            return and_(*queries)
        return None

    def _parse_number_state(self, state, table_field):
        if state.value is not None:
            return table_field == state.value
        return None

    def _parse_number_interval_state(self, state, table_field):
        return self._parse_interval(state.start, state.end, table_field)

    def _parse_string_state(self, state, table_field):
        # An empty search text must not filter: LIKE '%%' would still drop
        # NULL columns, and NOT_CONTAINS '' would exclude every row.
        if not state.text:
            return None
        text = '%%%s%%' % state.text.lower()
        # The pattern is lower-cased, so lower the column as well; otherwise
        # the match fails on upper-case data where LIKE is case sensitive.
        retval = func.lower(table_field).like(text)
        if state.mode == StringQueryState.NOT_CONTAINS:
            retval = not_(retval)
        return retval

    def _parse_date_state(self, state, table_field):
        if state.date is not None:
            return table_field == state.date
        return None

    def _parse_date_interval_state(self, state, table_field):
        return self._parse_interval(state.start, state.end, table_field)