# Source code for stoqlib.domain.test.domaintest

# -*- coding: utf-8 -*-
# vi:si:et:sw=4:sts=4:ts=4

##
## Copyright (C) 2006-2013 Async Open Source <http://www.async.com.br>
## All rights reserved
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., or visit: http://www.gnu.org/.
##
## Author(s): Stoq Team <stoq-devel@async.com.br>
##
""" Base module to be used by all domain test modules"""

import contextlib
import datetime
from decimal import Decimal
import os

import mock
import unittest

from stoqlib.lib.kiwilibrary import library
library  # pylint: disable=W0104

import stoqlib
from stoqlib.database.runtime import (get_current_branch,
                                      new_store,
                                      StoqlibStore)
from stoqlib.database.testsuite import StoqlibTestsuiteTracer
from stoqlib.domain.base import Domain
from stoqlib.domain.exampledata import ExampleCreator
from stoqlib.lib.dateutils import localdate, localdatetime


class FakeStoqConfig:
    """In-memory replacement for the Stoq configuration object.

    Nothing is read from or written to disk; the object only records what
    tests hand to it so they can inspect it afterwards.
    """

    def __init__(self, settings):
        self.settings = settings
        # Filled in by set_from_options()
        self.options = None
        # Switched to True by flush()
        self.flushed = False

    def items(self, name):
        """Return the items of section *name*, which is always empty."""
        return []

    def get_settings(self):
        """Return the settings object given to the constructor."""
        return self.settings

    def set_from_options(self, options):
        """Remember *options* for later inspection."""
        self.options = options

    def load_settings(self, settings):
        """Accept *settings* and ignore them."""
        return None

    def get(self, section, value):
        """Return u'' for Database/enable_production, None for anything else."""
        if section != u'Database' or value != u'enable_production':
            return None
        return u''

    def flush(self):
        """Mark the configuration as flushed without writing anything."""
        self.flushed = True


class FakeStore(object):
    """Minimal store double whose only operation, close(), does nothing."""

    def close(self):
        """Pretend to close the store."""
        return None


class FakeDatabaseSettings:
    """Database settings double describing a bogus, unreachable server."""

    def __init__(self, store):
        self.store = store
        # Connection parameters (deliberately fake values)
        self.address = u'invalid'
        self.dbname = u'stoq'
        self.port = 12345
        self.username = u'username'
        self.password = u'password'
        # What check_database_address() reports; tests may flip it
        self.check = False

    def check_database_address(self):
        """Return the value currently held in ``self.check``."""
        return self.check

    def has_database(self):
        """Report that the database does not exist."""
        return False

    def get_command_line_arguments(self):
        """Return the flag/value list for the fake connection parameters."""
        pairs = [('-d', self.dbname),
                 ('-p', unicode(self.port)),
                 ('-u', self.username),
                 ('-w', self.password)]
        return [token for pair in pairs for token in pair]

    def get_default_connection(self):
        """Return a new no-op FakeStore."""
        return FakeStore()

    def create_super_store(self):
        """Return a new no-op FakeStore."""
        return FakeStore()


class ReadOnlyStore(StoqlibStore):
    """Wraps a normal store but never modifies it.

    :meth:`get` is forwarded to the wrapped store and :meth:`fetch` returns
    its argument unchanged, while add/flush/commit/rollback/close are all
    no-ops.
    """

    def __init__(self, database, real_store):
        # Intentionally *not* calling StoqlibStore.__init__ since this
        # creates an additional database connection. *database* is unused.
        self.real_store = real_store
        # Set by tests through FakeNamespace.set_retval()
        self.retval = False

    #
    # Store
    #

    def add(self, obj):
        """Ignore *obj*; nothing is added to the real store."""
        pass

    def flush(self):
        """No-op."""
        pass

    def get(self, cls, key_id):
        """Look *key_id* up in the wrapped store."""
        return self.real_store.get(cls, key_id)

    #
    # Stoqlib Store
    #

    def fetch(self, obj):
        """Return *obj* unchanged."""
        return obj

    def rollback(self, close=True):
        """No-op."""
        pass

    def commit(self):
        """No-op."""
        pass

    def close(self):
        """No-op."""
        pass

    #
    # Comparison
    #

    def __eq__(self, other):
        # Wrappers are equal when they wrap equal stores. An object with no
        # ``real_store`` attribute is never equal; defaulting to None here
        # would make a wrapper around None equal to any such object.
        try:
            other_real_store = other.real_store
        except AttributeError:
            return False
        return self.real_store == other_real_store

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly
        return not self.__eq__(other)

    def __hash__(self):
        # Objects that compare equal must hash equally
        return hash(self.real_store)
class FakeNamespace(object):
    """Commonly used mock objects go in here."""

    def __init__(self):
        api = mock.Mock()
        api.get_current_branch = get_current_branch
        self.api = api

        self.DatabaseSettings = FakeDatabaseSettings
        self.StoqConfig = FakeStoqConfig

        # Mocked datetime module whose "today"/"now" are 2012-01-01
        fake_datetime = mock.MagicMock(datetime)
        fake_datetime.datetime.today.return_value = localdatetime(2012, 1, 1)
        fake_datetime.datetime.now.return_value = localdatetime(2012, 1, 1)
        fake_datetime.date.today.return_value = localdate(2012, 1, 1).date()
        self.datetime = fake_datetime

    def set_store(self, store):
        """Expose *store* through read-only wrappers on the fake api.

        This object is normally used as a class attribute, so this has to be
        called again whenever a new store is obtained (or with None to drop
        it).
        """
        database = mock.Mock()
        readonly = ReadOnlyStore(database, store)
        self.api.store = readonly
        # new_store() hands out a separate wrapper around the same store
        self.api.new_store.return_value = ReadOnlyStore(database, store)
        if store is None:
            return
        store.readonly = readonly

    def set_retval(self, retval):
        """Set ``retval`` on the wrapper currently exposed as ``api.store``."""
        current = self.api.store
        current.retval = retval
class DomainTest(unittest.TestCase, ExampleCreator):
    """Base class for domain tests.

    One store is opened per test class in ``setUpClass`` and every test's
    changes are rolled back in ``tearDown``. ``fake`` is a single
    FakeNamespace shared by all subclasses; it is pointed at the current
    class's store in ``setUpClass`` and reset in ``tearDownClass``.
    """

    fake = FakeNamespace()

    def __init__(self, test):
        unittest.TestCase.__init__(self, test)
        ExampleCreator.__init__(self)

    @classmethod
    def setUpClass(cls):
        cls.store = new_store()
        cls.fake.set_store(cls.store)

    @classmethod
    def tearDownClass(cls):
        cls.store.close()
        cls.fake.set_store(None)

    def setUp(self):
        # set_store() is not a TestCase method, so it presumably comes from
        # ExampleCreator
        self.set_store(self.store)

    def tearDown(self):
        # Discard the test's changes but keep the class-wide store open
        self.store.rollback(close=False)
        self.clear()

    def assertNotCalled(self, mocked):
        """Assert that *mocked* was never called."""
        self.assertEqual(mocked.call_count, 0)

    def assertCalledOnceWith(self, mocked, *args, **kwargs):
        """Assert that *mocked* was called exactly once, with these args."""
        mocked.assert_called_once_with(*args, **kwargs)

    def assertHasCalls(self, mocked, *args, **kwargs):
        """Forward to ``mocked.assert_has_calls``."""
        mocked.assert_has_calls(*args, **kwargs)

    def get_oficial_plugins_names(self):
        """Get the names of the official plugins.

        pluginmanager also works with plugins placed in the same directory
        as the stoq repository, so this helper returns only the ones defined
        inside the stoq repository itself: every entry of its ``plugins``
        directory that does not start with ``__init__``, as unicode.
        """
        base_dir = os.path.dirname(os.path.dirname(stoqlib.__file__))
        plugins_dir = os.path.join(base_dir, 'plugins')
        return set(unicode(d) for d in os.listdir(plugins_dir)
                   if not d.startswith('__init__'))

    @contextlib.contextmanager
    def count_tracer(self):
        """Count the number of statements executed inside this context.

        This is useful for local performance testing, where the number of
        statements for a given operation shouldn't increase.

        To behave consistently whether one test or many tests are run, the
        store is flushed and its caches are invalidated first, so the count
        will be higher than in the actual application. The tracer is always
        removed on exit, even if the body raises.
        """
        self.store.flush()
        self.store.invalidate()
        tracer = StoqlibTestsuiteTracer()
        tracer.install()
        try:
            yield tracer
        finally:
            # Never leave a tracer installed for the following tests
            tracer.remove()

    @contextlib.contextmanager
    def sysparam(self, **kwargs):
        """Update a set of system parameters within a context.

        kwargs maps parameter name -> new value. Supported value types are
        bool, int, Domain instances (or None), strings and Decimal; any
        other type raises NotImplementedError.

        Every parameter changed here is restored to its previous value when
        leaving the scope, even if the body (or a later parameter) raises.
        """
        from stoqlib.lib.parameters import sysparam
        store = self.store

        def accessors(value):
            # (getter, setter) pair for the type of the new value. Exact
            # type checks keep bool and int apart (bool subclasses int).
            if type(value) is bool:
                return sysparam.get_bool, sysparam.set_bool
            if type(value) is int:
                return sysparam.get_int, sysparam.set_int
            if isinstance(value, Domain) or value is None:
                return (lambda param: sysparam.get_object(store, param),
                        sysparam.set_object)
            if isinstance(value, basestring):
                return sysparam.get_string, sysparam.set_string
            if isinstance(value, Decimal):
                return sysparam.get_decimal, sysparam.set_decimal
            raise NotImplementedError(type(value))

        # (setter, param, old value) for every parameter touched so far
        changed = []
        try:
            for param, value in kwargs.items():
                getter, setter = accessors(value)
                changed.append((setter, param, getter(param)))
                setter(store, param, value)
            yield
        finally:
            # Restore through the same typed setter used to change the value,
            # so e.g. an old None string is not routed through set_object
            for setter, param, old_value in reversed(changed):
                setter(store, param, old_value)

    def collect_sale_models(self, sale):
        """Return the models related to *sale* in a deterministic order.

        The list holds the sale, its group, invoice and payments. It then
        holds, for each item sorted by sellable description, the sellable,
        its stock item in the current branch (for the item's batch) and the
        item itself.

        Side effect: the last space-separated word is stripped from the
        first payment's description (presumably so generated text does not
        vary between runs; confirm).
        """
        payments = list(sale.payments)
        models = [sale, sale.group, sale.invoice]
        models.extend(payments)

        branch = get_current_branch(self.store)
        items = sorted(sale.get_items(),
                       key=lambda item: item.sellable.description)
        for item in items:
            models.append(item.sellable)
            stock_item = item.sellable.product_storable.get_stock_item(
                branch, batch=item.batch)
            models.append(stock_item)
            models.append(item)

        if payments:
            p = payments[0]
            p.description = p.description.rsplit(u' ', 1)[0]

        return models