Refactored model classmethods to manager methos

This commit is contained in:
Marc Aymerich 2015-09-04 10:22:14 +00:00
parent ad37c5fd71
commit 8753b94b8c
28 changed files with 335 additions and 212 deletions

View File

@ -407,8 +407,7 @@ Case
# Don't enforce one contact per account? remove account.email in favour of contacts? # Don't enforce one contact per account? remove account.email in favour of contacts?
#change class LogEntry(models.Model):
action_time = models.DateTimeField(_('action time'), auto_now=True) to auto_now_add
# Model operations on Manager instead of model method
# Mailer: mark as sent # Mailer: mark as sent
# Pending filter filter out orders zero metric from pending

View File

@ -17,7 +17,7 @@ from django.utils.html import escape
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.views.decorators.debug import sensitive_post_parameters from django.views.decorators.debug import sensitive_post_parameters
from ..utils.python import random_ascii from ..utils.python import random_ascii, pairwise
from .forms import AdminPasswordChangeForm from .forms import AdminPasswordChangeForm
#from django.contrib.auth.forms import AdminPasswordChangeForm #from django.contrib.auth.forms import AdminPasswordChangeForm
@ -70,6 +70,31 @@ class AtLeastOneRequiredInlineFormSet(BaseInlineFormSet):
raise forms.ValidationError('At least one item required.') raise forms.ValidationError('At least one item required.')
class EnhaceSearchMixin(object):
def lookup_allowed(self, lookup, value):
""" allows any lookup """
if 'password' in lookup:
return False
return True
def get_search_results(self, request, queryset, search_term):
""" allows to specify field <field_name>:<search_term> """
search_fields = self.get_search_fields(request)
if ':' in search_term:
fields = {field.split('__')[0]: field for field in search_fields}
new_search_term = []
for part in search_term.split():
cur_search_term = ''
for field, term in pairwise(part.split(':')):
if field in fields:
queryset = queryset.filter(**{'%s__icontains' % fields[field]: term})
else:
cur_search_term += ':'.join((field, term))
new_search_term.append(cur_search_term)
search_term = ' '.join(new_search_term)
return super(EnhaceSearchMixin, self).get_search_results(request, queryset, search_term)
class ChangeViewActionsMixin(object): class ChangeViewActionsMixin(object):
""" Makes actions visible on the admin change view page. """ """ Makes actions visible on the admin change view page. """
change_view_actions = () change_view_actions = ()
@ -176,7 +201,11 @@ class ChangeAddFieldsMixin(object):
return super(ChangeAddFieldsMixin, self).get_form(request, obj, **defaults) return super(ChangeAddFieldsMixin, self).get_form(request, obj, **defaults)
class ExtendedModelAdmin(ChangeViewActionsMixin, ChangeAddFieldsMixin, ChangeListDefaultFilter, admin.ModelAdmin): class ExtendedModelAdmin(ChangeViewActionsMixin,
ChangeAddFieldsMixin,
ChangeListDefaultFilter,
EnhaceSearchMixin,
admin.ModelAdmin):
list_prefetch_related = None list_prefetch_related = None
def get_queryset(self, request): def get_queryset(self, request):

View File

@ -13,6 +13,11 @@ from orchestra.utils.mail import send_email_template
from . import settings from . import settings
class AccountManager(auth.UserManager):
def get_main(self):
return self.get(pk=settings.ACCOUNTS_MAIN_PK)
class Account(auth.AbstractBaseUser): class Account(auth.AbstractBaseUser):
# Username max_length determined by LINUX system user/group lentgh: 32 # Username max_length determined by LINUX system user/group lentgh: 32
username = models.CharField(_("username"), max_length=32, unique=True, username = models.CharField(_("username"), max_length=32, unique=True,
@ -39,7 +44,7 @@ class Account(auth.AbstractBaseUser):
"Unselect this instead of deleting accounts.")) "Unselect this instead of deleting accounts."))
date_joined = models.DateTimeField(_("date joined"), default=timezone.now) date_joined = models.DateTimeField(_("date joined"), default=timezone.now)
objects = auth.UserManager() objects = AccountManager()
USERNAME_FIELD = 'username' USERNAME_FIELD = 'username'
REQUIRED_FIELDS = ['email'] REQUIRED_FIELDS = ['email']
@ -55,10 +60,6 @@ class Account(auth.AbstractBaseUser):
def is_staff(self): def is_staff(self):
return self.is_superuser return self.is_superuser
@classmethod
def get_main(cls):
return cls.objects.get(pk=settings.ACCOUNTS_MAIN_PK)
def save(self, active_systemuser=False, *args, **kwargs): def save(self, active_systemuser=False, *args, **kwargs):
created = not self.pk created = not self.pk
if not created: if not created:

View File

@ -17,7 +17,7 @@ def validate_contact(request, bill, error=True):
message = msg.format(relation=_("Related"), account=account, url=url) message = msg.format(relation=_("Related"), account=account, url=url)
send(request, mark_safe(message)) send(request, mark_safe(message))
valid = False valid = False
main = type(bill).account.field.rel.to.get_main() main = type(bill).account.field.rel.to.objects.get_main()
if not hasattr(main, 'billcontact'): if not hasattr(main, 'billcontact'):
account = force_text(main) account = force_text(main)
url = reverse('admin:accounts_account_change', args=(main.id,)) url = reverse('admin:accounts_account_change', args=(main.id,))

View File

@ -130,7 +130,7 @@ class Bill(models.Model):
@cached_property @cached_property
def seller(self): def seller(self):
return Account.get_main().billcontact return Account.objects.get_main().billcontact
@cached_property @cached_property
def buyer(self): def buyer(self):

View File

@ -111,7 +111,7 @@ class Bind9MasterDomainBackend(ServiceController):
from orchestra.contrib.orchestration.manager import router from orchestra.contrib.orchestration.manager import router
operation = Operation(backend, domain, Operation.SAVE) operation = Operation(backend, domain, Operation.SAVE)
servers = [] servers = []
for route in router.get_routes(operation): for route in router.objects.get_for_operation(operation):
servers.append(route.host.get_ip()) servers.append(route.host.get_ip())
return servers return servers

View File

@ -46,7 +46,7 @@ class BatchDomainCreationAdminForm(forms.ModelForm):
if not cleaned_data['account']: if not cleaned_data['account']:
account = None account = None
for name in [cleaned_data['name']] + self.extra_names: for name in [cleaned_data['name']] + self.extra_names:
parent = Domain.get_parent_domain(name) parent = Domain.objects.get_parent(name)
if not parent: if not parent:
# Fake an account to make django validation happy # Fake an account to make django validation happy
account_model = self.fields['account']._queryset.model account_model = self.fields['account']._queryset.model

View File

@ -8,6 +8,21 @@ from orchestra.utils.python import AttrDict
from . import settings, validators, utils from . import settings, validators, utils
class DomainQuerySet(models.QuerySet):
def get_parent(self, name, top=False):
""" get the next domain on the chain """
split = name.split('.')
parent = None
for i in range(1, len(split)-1):
name = '.'.join(split[i:])
domain = Domain.objects.filter(name=name)
if domain:
parent = domain.get()
if not top:
return parent
return parent
class Domain(models.Model): class Domain(models.Model):
name = models.CharField(_("name"), max_length=256, unique=True, name = models.CharField(_("name"), max_length=256, unique=True,
help_text=_("Domain or subdomain name."), help_text=_("Domain or subdomain name."),
@ -51,23 +66,11 @@ class Domain(models.Model):
"servers how long they should keep the data in cache. " "servers how long they should keep the data in cache. "
"The default value is <tt>%s</tt>.") % settings.DOMAINS_DEFAULT_MIN_TTL) "The default value is <tt>%s</tt>.") % settings.DOMAINS_DEFAULT_MIN_TTL)
objects = DomainQuerySet.as_manager()
def __str__(self): def __str__(self):
return self.name return self.name
@classmethod
def get_parent_domain(cls, name, top=False):
""" get the next domain on the chain """
split = name.split('.')
parent = None
for i in range(1, len(split)-1):
name = '.'.join(split[i:])
domain = Domain.objects.filter(name=name)
if domain:
parent = domain.get()
if not top:
return parent
return parent
@property @property
def origin(self): def origin(self):
return self.top or self return self.top or self
@ -122,7 +125,7 @@ class Domain(models.Model):
return self.origin.subdomain_set.all().prefetch_related('records') return self.origin.subdomain_set.all().prefetch_related('records')
def get_parent(self, top=False): def get_parent(self, top=False):
return self.get_parent_domain(self.name, top=top) return type(self).objects.get_parent(self.name, top=top)
def render_zone(self): def render_zone(self):
origin = self.origin origin = self.origin

View File

@ -31,7 +31,7 @@ class DomainSerializer(AccountSerializerMixin, HyperlinkedModelSerializer):
def clean_name(self, attrs, source): def clean_name(self, attrs, source):
""" prevent users creating subdomains of other users domains """ """ prevent users creating subdomains of other users domains """
name = attrs[source] name = attrs[source]
parent = Domain.get_parent_domain(name) parent = Domain.objects.get_parent(name)
if parent and parent.account != self.account: if parent and parent.account != self.account:
raise ValidationError(_("Can not create subdomains of other users domains")) raise ValidationError(_("Can not create subdomains of other users domains"))
return attrs return attrs

View File

@ -4,28 +4,59 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from django.contrib.admin.templatetags.admin_urls import add_preserved_filters from django.contrib.admin.templatetags.admin_urls import add_preserved_filters
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.contrib.admin.utils import unquote from django.contrib.admin.utils import unquote
from django.contrib.admin.templatetags.admin_static import static
from orchestra.admin.utils import admin_link, admin_date from orchestra.admin.utils import admin_link, admin_date
class LogEntryAdmin(admin.ModelAdmin): class LogEntryAdmin(admin.ModelAdmin):
list_display = ( list_display = (
'__str__', 'display_action_time', 'user_link', 'id', 'display_message', 'display_action_time', 'user_link',
)
list_filter = (
'action_flag',
('content_type', admin.RelatedOnlyFieldListFilter),
) )
list_filter = ('action_flag', 'content_type',)
date_hierarchy = 'action_time' date_hierarchy = 'action_time'
search_fields = ('object_repr', 'change_message') search_fields = ('object_repr', 'change_message', 'user__username')
fields = ( fields = (
'user_link', 'content_object_link', 'display_action_time', 'display_action', 'change_message' 'user_link', 'content_object_link', 'display_action_time', 'display_action',
'change_message'
) )
readonly_fields = ( readonly_fields = (
'user_link', 'content_object_link', 'display_action_time', 'display_action', 'user_link', 'content_object_link', 'display_action_time', 'display_action',
) )
actions = None actions = None
list_select_related = ('user', 'content_type')
user_link = admin_link('user') user_link = admin_link('user')
display_action_time = admin_date('action_time', short_description=_("Time")) display_action_time = admin_date('action_time', short_description=_("Time"))
def display_message(self, log):
edit = '<a href="%(url)s"><img src="%(img)s"></img></a>' % {
'url': reverse('admin:admin_logentry_change', args=(log.pk,)),
'img': static('admin/img/icon_changelink.gif'),
}
if log.is_addition():
return _('Added "%(link)s". %(edit)s') % {
'link': self.content_object_link(log),
'edit': edit
}
elif log.is_change():
return _('Changed "%(link)s" - %(changes)s %(edit)s') % {
'link': self.content_object_link(log),
'changes': log.change_message,
'edit': edit,
}
elif log.is_deletion():
return _('Deleted "%(object)s." %(edit)s') % {
'object': log.object_repr,
'edit': edit,
}
display_message.short_description = _("Message")
display_message.admin_order_field = 'action_flag'
display_message.allow_tags = True
def display_action(self, log): def display_action(self, log):
if log.is_addition(): if log.is_addition():
return _("Added") return _("Added")

View File

@ -92,7 +92,7 @@ class Command(BaseCommand):
context = { context = {
'servers': ', '.join(servers), 'servers': ', '.join(servers),
} }
if not confirm("\n\nAre your sure to execute the previous scripts on %(servers)s (yes/no)? " % context) if not confirm("\n\nAre your sure to execute the previous scripts on %(servers)s (yes/no)? " % context):
return return
if not dry: if not dry:
logs = manager.execute(scripts, serialize=serialize, async=True) logs = manager.execute(scripts, serialize=serialize, async=True)

View File

@ -59,7 +59,7 @@ def generate(operations):
for operation in operations: for operation in operations:
logger.debug("Queued %s" % str(operation)) logger.debug("Queued %s" % str(operation))
if operation.routes is None: if operation.routes is None:
operation.routes = router.get_routes(operation, cache=cache) operation.routes = router.objects.get_for_operation(operation, cache=cache)
for route in operation.routes: for route in operation.routes:
# TODO key by action.async # TODO key by action.async
async_action = route.action_is_async(operation.action) async_action = route.action_is_async(operation.action)
@ -196,7 +196,7 @@ def collect(instance, action, **kwargs):
continue continue
operation = Operation(backend_cls, selected, iaction) operation = Operation(backend_cls, selected, iaction)
# Only schedule operations if the router has execution routes # Only schedule operations if the router has execution routes
routes = router.get_routes(operation, cache=route_cache) routes = router.objects.get_for_operation(operation, cache=route_cache)
if routes: if routes:
operation.routes = routes operation.routes = routes
if iaction != Operation.DELETE: if iaction != Operation.DELETE:

View File

@ -144,6 +144,31 @@ class BackendOperation(models.Model):
autodiscover_modules('backends') autodiscover_modules('backends')
class RouteQuerySet(models.QuerySet):
def get_for_operation(self, operation, **kwargs):
cache = kwargs.get('cache', {})
if not cache:
for route in self.filter(is_active=True).select_related('host'):
for action in route.backend_class.get_actions():
key = (route.backend, action)
try:
cache[key].append(route)
except KeyError:
cache[key] = [route]
routes = []
backend_cls = operation.backend
key = (backend_cls.get_name(), operation.action)
try:
target_routes = cache[key]
except KeyError:
pass
else:
for route in target_routes:
if route.matches(operation.instance):
routes.append(route)
return routes
class Route(models.Model): class Route(models.Model):
""" """
Defines the routing that determine in which server a backend is executed Defines the routing that determine in which server a backend is executed
@ -163,6 +188,7 @@ class Route(models.Model):
# default=MethodBackend.get_default()) # default=MethodBackend.get_default())
is_active = models.BooleanField(_("active"), default=True) is_active = models.BooleanField(_("active"), default=True)
objects = RouteQuerySet.as_manager()
class Meta: class Meta:
unique_together = ('backend', 'host') unique_together = ('backend', 'host')
@ -174,30 +200,6 @@ class Route(models.Model):
def backend_class(self): def backend_class(self):
return ServiceBackend.get_backend(self.backend) return ServiceBackend.get_backend(self.backend)
@classmethod
def get_routes(cls, operation, **kwargs):
cache = kwargs.get('cache', {})
if not cache:
for route in cls.objects.filter(is_active=True).select_related('host'):
for action in route.backend_class.get_actions():
key = (route.backend, action)
try:
cache[key].append(route)
except KeyError:
cache[key] = [route]
routes = []
backend_cls = operation.backend
key = (backend_cls.get_name(), operation.action)
try:
target_routes = cache[key]
except KeyError:
pass
else:
for route in target_routes:
if route.matches(operation.instance):
routes.append(route)
return routes
def clean(self): def clean(self):
if not self.match: if not self.match:
self.match = 'True' self.match = 'True'

View File

@ -30,12 +30,12 @@ class RouterTests(BaseTestCase):
route = Route.objects.create(backend=backend, host=self.host, match='True') route = Route.objects.create(backend=backend, host=self.host, match='True')
operation = Operation(backend=TestBackend, instance=route, action='save') operation = Operation(backend=TestBackend, instance=route, action='save')
self.assertEqual(1, len(Route.get_routes(operation))) self.assertEqual(1, len(Route.objects.get_for_operation(operation)))
route = Route.objects.create(backend=backend, host=self.host1, route = Route.objects.create(backend=backend, host=self.host1,
match='route.backend == "%s"' % TestBackend.get_name()) match='route.backend == "%s"' % TestBackend.get_name())
self.assertEqual(2, len(Route.get_routes(operation))) self.assertEqual(2, len(Route.objects.get_for_operation(operation)))
route = Route.objects.create(backend=backend, host=self.host2, route = Route.objects.create(backend=backend, host=self.host2,
match='route.backend == "something else"') match='route.backend == "something else"')
self.assertEqual(2, len(Route.get_routes(operation))) self.assertEqual(2, len(Route.objects.get_for_operation(operation)))

View File

@ -49,7 +49,7 @@ class BilledOrderListFilter(SimpleListFilter):
metric_pks = [] metric_pks = []
prefetch_valid_metrics = Prefetch('metrics', to_attr='valid_metrics', prefetch_valid_metrics = Prefetch('metrics', to_attr='valid_metrics',
queryset=MetricStorage.objects.filter(created_on__gt=F('order__billed_on'), queryset=MetricStorage.objects.filter(created_on__gt=F('order__billed_on'),
created_on__lte=(F('updated_on')-mindelta)) created_on__lte=(F('updated_on')-mindelta)).exclude(value=0)
) )
metric_queryset = queryset.exclude(service__metric='').exclude(billed_on__isnull=True) metric_queryset = queryset.exclude(service__metric='').exclude(billed_on__isnull=True)
for order in metric_queryset.prefetch_related(prefetch_valid_metrics): for order in metric_queryset.prefetch_related(prefetch_valid_metrics):
@ -61,26 +61,36 @@ class BilledOrderListFilter(SimpleListFilter):
break break
return metric_pks return metric_pks
def queryset(self, request, queryset): def filter_pending(self, queryset, reverse=False):
now = timezone.now()
Service = apps.get_model(settings.ORDERS_SERVICE_MODEL) Service = apps.get_model(settings.ORDERS_SERVICE_MODEL)
ignore_qs = Q()
for order in queryset.distinct('service_id').only('service'):
service = order.service
delta = service.handler.get_ignore_delta()
if delta is not None:
ignore_qs = ignore_qs | Q(service_id=service.id, registered_on__gt=now-delta)
ignore_qs = queryset.exclude(ignore_qs)
pending_qs = Q(
Q(pk__in=self.get_pending_metric_pks(ignore_qs)) |
Q(billed_until__isnull=True) | Q(~Q(service__billing_period=Service.NEVER) &
Q(billed_until__lt=now))
)
if reverse:
return queryset.exclude(pending_qs)
else:
return ignore_qs.filter(pending_qs)
def queryset(self, request, queryset):
now = timezone.now()
if self.value() == 'yes': if self.value() == 'yes':
return queryset.filter(billed_until__isnull=False, billed_until__gte=timezone.now()) return queryset.filter(billed_until__isnull=False, billed_until__gte=now)
elif self.value() == 'no': elif self.value() == 'no':
return queryset.exclude(billed_until__isnull=False, billed_until__gte=timezone.now()) return queryset.exclude(billed_until__isnull=False, billed_until__gte=now)
elif self.value() == 'pending': elif self.value() == 'pending':
return queryset.filter( return self.filter_pending(queryset)
Q(pk__in=self.get_pending_metric_pks(queryset)) | Q(
Q(billed_until__isnull=True) | Q(~Q(service__billing_period=Service.NEVER) &
Q(billed_until__lt=timezone.now()))
)
)
elif self.value() == 'not_pending': elif self.value() == 'not_pending':
return queryset.exclude( return self.filter_pending(queryset, reverse=True)
Q(pk__in=self.get_pending_metric_pks(queryset)) | Q(
Q(billed_until__isnull=True) | Q(~Q(service__billing_period=Service.NEVER) &
Q(billed_until__lt=timezone.now()))
)
)
return queryset return queryset

View File

@ -7,6 +7,7 @@ from django.db.models import F, Q, Sum
from django.apps import apps from django.apps import apps
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.utils import timezone from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -104,6 +105,49 @@ class OrderQuerySet(models.QuerySet):
""" return inactive orders """ """ return inactive orders """
return self.filter(cancelled_on__lte=timezone.now(), **kwargs) return self.filter(cancelled_on__lte=timezone.now(), **kwargs)
def update_by_instance(self, instance, service=None, commit=True):
updates = []
if service is None:
Service = apps.get_model(settings.ORDERS_SERVICE_MODEL)
services = Service.objects.filter_by_instance(instance)
else:
services = [service]
for service in services:
orders = Order.objects.by_object(instance, service=service)
orders = orders.select_related('service').active()
if service.handler.matches(instance):
if not orders:
account_id = getattr(instance, 'account_id', instance.pk)
if account_id is None:
# New account workaround -> user.account_id == None
continue
ignore = service.handler.get_ignore(instance)
order = self.model(
content_object=instance,
content_object_repr=str(instance),
service=service,
account_id=account_id,
ignore=ignore)
if commit:
order.save()
updates.append((order, 'created'))
logger.info("CREATED new order id: {id}".format(id=order.id))
else:
if len(orders) > 1:
raise ValueError("A single active order was expected.")
order = orders[0]
updates.append((order, 'updated'))
if commit:
order.update()
elif orders:
if len(orders) > 1:
raise ValueError("A single active order was expected.")
order = orders[0]
order.cancel(commit=commit)
logger.info("CANCELLED order id: {id}".format(id=order.id))
updates.append((order, 'cancelled'))
return updates
class Order(models.Model): class Order(models.Model):
account = models.ForeignKey('accounts.Account', verbose_name=_("account"), account = models.ForeignKey('accounts.Account', verbose_name=_("account"),
@ -132,54 +176,16 @@ class Order(models.Model):
def __str__(self): def __str__(self):
return str(self.service) return str(self.service)
@classmethod
def update_orders(cls, instance, service=None, commit=True):
updates = []
if service is None:
Service = apps.get_model(settings.ORDERS_SERVICE_MODEL)
services = Service.get_services(instance)
else:
services = [service]
for service in services:
orders = Order.objects.by_object(instance, service=service)
orders = orders.select_related('service').active()
if service.handler.matches(instance):
if not orders:
account_id = getattr(instance, 'account_id', instance.pk)
if account_id is None:
# New account workaround -> user.account_id == None
continue
ignore = service.handler.get_ignore(instance)
order = cls(
content_object=instance,
content_object_repr=str(instance),
service=service,
account_id=account_id,
ignore=ignore)
if commit:
order.save()
updates.append((order, 'created'))
logger.info("CREATED new order id: {id}".format(id=order.id))
else:
if len(orders) > 1:
raise ValueError("A single active order was expected.")
order = orders[0]
updates.append((order, 'updated'))
if commit:
order.update()
elif orders:
if len(orders) > 1:
raise ValueError("A single active order was expected.")
order = orders[0]
order.cancel(commit=commit)
logger.info("CANCELLED order id: {id}".format(id=order.id))
updates.append((order, 'cancelled'))
return updates
@classmethod @classmethod
def get_bill_backend(cls): def get_bill_backend(cls):
return import_class(settings.ORDERS_BILLING_BACKEND)() return import_class(settings.ORDERS_BILLING_BACKEND)()
def clean(self):
if self.billed_on < self.registered_on:
raise ValidationError(_("Billed date can not be earlier than registered on."))
if self.billed_until and not self.billed_on:
raise ValidationError(_("Billed on is missing while billed until is being provided."))
def update(self): def update(self):
instance = self.content_object instance = self.content_object
if instance is None: if instance is None:
@ -189,7 +195,7 @@ class Order(models.Model):
if handler.metric: if handler.metric:
metric = handler.get_metric(instance) metric = handler.get_metric(instance)
if metric is not None: if metric is not None:
MetricStorage.store(self, metric) MetricStorage.objects.store(self, metric)
metric = ', metric:{}'.format(metric) metric = ', metric:{}'.format(metric)
description = handler.get_order_description(instance) description = handler.get_order_description(instance)
logger.info("UPDATED order id:{id}, description:{description}{metric}".format( logger.info("UPDATED order id:{id}, description:{description}{metric}".format(
@ -229,6 +235,8 @@ class Order(models.Model):
for metric in self.metrics.filter(created_on__lt=end).order_by('id'): for metric in self.metrics.filter(created_on__lt=end).order_by('id'):
created = metric.created_on created = metric.created_on
if created > ini: if created > ini:
if prev is None:
raise ValueError("Metric storage information is inconsistent.")
cini = prev.created_on cini = prev.created_on
if not result: if not result:
cini = ini cini = ini
@ -259,27 +267,13 @@ class Order(models.Model):
return decimal.Decimal(0) return decimal.Decimal(0)
class MetricStorage(models.Model): class MetricStorageQuerySet(models.QuerySet):
""" Stores metric state for future billing """ def store(self, order, value):
order = models.ForeignKey(Order, verbose_name=_("order"), related_name='metrics')
value = models.DecimalField(_("value"), max_digits=16, decimal_places=2)
created_on = models.DateField(_("created"), auto_now_add=True)
# TODO time field?
updated_on = models.DateTimeField(_("updated"))
class Meta:
get_latest_by = 'id'
def __str__(self):
return str(self.order)
@classmethod
def store(cls, order, value):
now = timezone.now() now = timezone.now()
try: try:
last = cls.objects.filter(order=order).latest() last = self.filter(order=order).latest()
except cls.DoesNotExist: except self.model.DoesNotExist:
cls.objects.create(order=order, value=value, updated_on=now) self.create(order=order, value=value, updated_on=now)
else: else:
# Metric storage has per-day granularity (last value of the day is what counts) # Metric storage has per-day granularity (last value of the day is what counts)
if last.created_on == now.date(): if last.created_on == now.date():
@ -289,7 +283,24 @@ class MetricStorage(models.Model):
else: else:
error = decimal.Decimal(str(settings.ORDERS_METRIC_ERROR)) error = decimal.Decimal(str(settings.ORDERS_METRIC_ERROR))
if value > last.value+error or value < last.value-error: if value > last.value+error or value < last.value-error:
cls.objects.create(order=order, value=value, updated_on=now) self.create(order=order, value=value, updated_on=now)
else: else:
last.updated_on = now last.updated_on = now
last.save(update_fields=['updated_on']) last.save(update_fields=['updated_on'])
class MetricStorage(models.Model):
""" Stores metric state for future billing """
order = models.ForeignKey(Order, verbose_name=_("order"), related_name='metrics')
value = models.DecimalField(_("value"), max_digits=16, decimal_places=2)
created_on = models.DateField(_("created"), auto_now_add=True)
# TODO time field?
updated_on = models.DateTimeField(_("updated"))
objects = MetricStorageQuerySet.as_manager()
class Meta:
get_latest_by = 'id'
def __str__(self):
return str(self.order)

View File

@ -32,8 +32,8 @@ def update_orders(sender, **kwargs):
if sender._meta.app_label not in settings.ORDERS_EXCLUDED_APPS: if sender._meta.app_label not in settings.ORDERS_EXCLUDED_APPS:
instance = kwargs['instance'] instance = kwargs['instance']
if type(instance) in services: if type(instance) in services:
Order.update_orders(instance) Order.objects.update_by_instance(instance)
elif not hasattr(instance, 'account'): elif not hasattr(instance, 'account'):
related = helpers.get_related_object(instance) related = helpers.get_related_object(instance)
if related and related != instance: if related and related != instance:
Order.update_orders(related) Order.objects.update_by_instance(related)

View File

@ -1,5 +1,5 @@
{% extends "admin/base_site.html" %} {% extends "admin/base_site.html" %}
{% load i18n l10n staticfiles admin_urls utils %} {% load i18n l10n staticfiles admin_urls utils orders %}
{% block extrastyle %} {% block extrastyle %}
{{ block.super }} {{ block.super }}
@ -50,7 +50,7 @@ $(document).ready( function () {
{% if not lines %} {% if not lines %}
<table> <table>
<thead> <thead>
<tr><th>{% trans 'Nothing to bill' %}</th></tr> <tr><th>{% trans 'Nothing to bill, all lines have size&times;quantity 0.' %}</th></tr>
</thead> </thead>
</table> </table>
{% else %} {% else %}
@ -67,7 +67,7 @@ $(document).ready( function () {
<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Discount per {{ discount.type }} <br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Discount per {{ discount.type }}
{% endfor %} {% endfor %}
</td> </td>
<td>{{ line.ini | date }} to {{ line.end | date }}</td> <td>{{ line | periodformat }}</td>
<td>{{ line.size | floatformat:"-2" }}&times;{{ line.metric | floatformat:"-2"}}</td> <td>{{ line.size | floatformat:"-2" }}&times;{{ line.metric | floatformat:"-2"}}</td>
<td> <td>
&nbsp;{{ line.subtotal | floatformat:"-2" }} &euro; &nbsp;{{ line.subtotal | floatformat:"-2" }} &euro;

View File

@ -0,0 +1,19 @@
import datetime
from django import template
from django.template.defaultfilters import date
register = template.Library()
@register.filter
def periodformat(line):
if line.ini == line.end:
return date(line.ini)
if line.ini.day == 1 and line.end.day == 1:
end = line.end - datetime.timedelta(days=1)
if line.ini.month == end.month:
return date(line.ini, "N Y")
return '%s to %s' % (date(line.ini, "N Y"), date(end, "N Y"))
return '%s to %s' % (date(line.ini), date(line.end))

View File

@ -164,6 +164,23 @@ class Resource(models.Model):
return tasks.monitor(self.pk) return tasks.monitor(self.pk)
class ResourceDataQuerySet(models.QuerySet):
def get_or_create(self, obj, resource):
ct = ContentType.objects.get_for_model(type(obj))
try:
return self.get(
content_type=ct,
object_id=obj.pk,
resource=resource
), False
except self.model.DoesNotExist:
return self.create(
content_object=obj,
resource=resource,
allocated=resource.default_allocation
), True
class ResourceData(models.Model): class ResourceData(models.Model):
""" Stores computed resource usage and allocation """ """ Stores computed resource usage and allocation """
resource = models.ForeignKey(Resource, related_name='dataset', verbose_name=_("resource")) resource = models.ForeignKey(Resource, related_name='dataset', verbose_name=_("resource"))
@ -177,6 +194,7 @@ class ResourceData(models.Model):
editable=False) editable=False)
content_object = GenericForeignKey() content_object = GenericForeignKey()
objects = ResourceDataQuerySet.as_manager()
class Meta: class Meta:
unique_together = ('resource', 'content_type', 'object_id') unique_together = ('resource', 'content_type', 'object_id')
@ -185,22 +203,6 @@ class ResourceData(models.Model):
def __str__(self): def __str__(self):
return "%s: %s" % (str(self.resource), str(self.content_object)) return "%s: %s" % (str(self.resource), str(self.content_object))
@classmethod
def get_or_create(cls, obj, resource):
ct = ContentType.objects.get_for_model(type(obj))
try:
return cls.objects.get(
content_type=ct,
object_id=obj.pk,
resource=resource
), False
except cls.DoesNotExist:
return cls.objects.create(
content_object=obj,
resource=resource,
allocated=resource.default_allocation
), True
@property @property
def unit(self): def unit(self):
return self.resource.unit return self.resource.unit

View File

@ -43,7 +43,7 @@ def monitor(resource_id, ids=None):
triggers = [] triggers = []
model = resource.content_type.model_class() model = resource.content_type.model_class()
for obj in model.objects.filter(**kwargs): for obj in model.objects.filter(**kwargs):
data, __ = ResourceData.get_or_create(obj, resource) data, __ = ResourceData.objects.get_or_create(obj, resource)
data.update() data.update()
if not resource.disable_trigger: if not resource.disable_trigger:
a = data.used a = data.used

View File

@ -57,7 +57,8 @@ class PHPListService(SoftwareService):
return settings.SAAS_PHPLIST_DB_USER return settings.SAAS_PHPLIST_DB_USER
def get_account(self): def get_account(self):
return self.instance.account.get_main() account_model = self.instance._meta.get_field_by_name('account')[0]
return account_model.objects.get_main()
def validate(self): def validate(self):
super(PHPListService, self).validate() super(PHPListService, self).validate()

View File

@ -247,7 +247,7 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
def generate_line(self, order, price, *dates, metric=1, discounts=None, computed=False): def generate_line(self, order, price, *dates, metric=1, discounts=None, computed=False):
""" """
discounts: already applied discounts on price discounts: extra discounts to apply
computed: price = price*size already performed computed: price = price*size already performed
""" """
if len(dates) == 2: if len(dates) == 2:
@ -271,14 +271,17 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
'metric': metric, 'metric': metric,
'discounts': [], 'discounts': [],
}) })
discounted = 0
for dtype, dprice in discounts:
self.generate_discount(line, dtype, dprice)
discounted += dprice
# TODO this is needed for all discounts?
subtotal += discounted
if subtotal > price: if subtotal > price:
self.generate_discount(line, self._PLAN, price-subtotal) plan_discount = price-subtotal
self.generate_discount(line, self._PLAN, plan_discount)
subtotal += plan_discount
for dtype, dprice in discounts:
subtotal += dprice
# Prevent compensations to refund money
if dtype == self._COMPENSATION and subtotal < 0:
dprice -= subtotal
if dprice:
self.generate_discount(line, dtype, dprice)
return line return line
def assign_compensations(self, givers, receivers, **options): def assign_compensations(self, givers, receivers, **options):
@ -318,7 +321,7 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
cend = comp.end cend = comp.end
if only_beyond: if only_beyond:
cini = beyond cini = beyond
elif not only_beyond: elif only_beyond:
continue continue
dsize += self.get_price_size(cini, cend) dsize += self.get_price_size(cini, cend)
# Extend billing point a little bit to benefit from a substantial discount # Extend billing point a little bit to benefit from a substantial discount
@ -359,8 +362,8 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
if intersect: if intersect:
csize += self.get_price_size(intersect.ini, intersect.end) csize += self.get_price_size(intersect.ini, intersect.end)
price = self.get_price(account, metric, position=position, rates=rates) price = self.get_price(account, metric, position=position, rates=rates)
price = price * size
cprice = price * csize cprice = price * csize
price = price * size
if order in priced: if order in priced:
priced[order][0] += price priced[order][0] += price
priced[order][1] += cprice priced[order][1] += cprice
@ -368,23 +371,25 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
priced[order] = (price, cprice) priced[order] = (price, cprice)
lines = [] lines = []
for order, prices in priced.items(): for order, prices in priced.items():
discounts = () if hasattr(order, 'new_billed_until'):
# Generate lines and discounts from order.nominal_price discounts = ()
price, cprice = prices # Generate lines and discounts from order.nominal_price
# Compensations > new_billed_until price, cprice = prices
dsize, new_end = self.apply_compensations(order, only_beyond=True) a = order.id
cprice += dsize*price # Compensations > new_billed_until
if cprice: dsize, new_end = self.apply_compensations(order, only_beyond=True)
discounts = ( cprice += dsize*price
(self._COMPENSATION, -cprice), if cprice:
) discounts = (
if new_end: (self._COMPENSATION, -cprice),
size = self.get_price_size(order.new_billed_until, new_end) )
price += price*size if new_end:
order.new_billed_until = new_end size = self.get_price_size(order.new_billed_until, new_end)
line = self.generate_line( price += price*size
order, price, ini, new_end or end, discounts=discounts, computed=True) order.new_billed_until = new_end
lines.append(line) line = self.generate_line(
order, price, ini, new_end or end, discounts=discounts, computed=True)
lines.append(line)
return lines return lines
def bill_registered_or_renew_events(self, account, porders, rates): def bill_registered_or_renew_events(self, account, porders, rates):
@ -503,7 +508,7 @@ class ServiceHandler(plugins.Plugin, metaclass=plugins.PluginMount):
recharges = [] recharges = []
rini = order.billed_on rini = order.billed_on
rend = min(bp, order.billed_until) rend = min(bp, order.billed_until)
bmetric = order.billed_metric bmetric = order.billed_metric or 0
bsize = self.get_price_size(rini, order.billed_until) bsize = self.get_price_size(rini, order.billed_until)
prepay_discount = self.get_price(account, bmetric) * bsize prepay_discount = self.get_price(account, bmetric) * bsize
prepay_discount = round(prepay_discount, 2) prepay_discount = round(prepay_discount, 2)

View File

@ -19,6 +19,18 @@ autodiscover_modules('handlers')
rate_class = import_class(settings.SERVICES_RATE_CLASS) rate_class = import_class(settings.SERVICES_RATE_CLASS)
class ServiceQuerySet(models.QuerySet):
def filter_by_instance(self, instance):
cache = caches.get_request_cache()
ct = ContentType.objects.get_for_model(instance)
key = 'services.Service-%i' % ct.pk
services = cache.get(key)
if services is None:
services = self.filter(content_type=ct, is_active=True)
cache.set(key, services)
return services
class Service(models.Model): class Service(models.Model):
NEVER = '' NEVER = ''
# DAILY = 'DAILY' # DAILY = 'DAILY'
@ -152,20 +164,11 @@ class Service(models.Model):
), ),
default=PREPAY) default=PREPAY)
objects = ServiceQuerySet.as_manager()
def __str__(self): def __str__(self):
return self.description return self.description
@classmethod
def get_services(cls, instance):
cache = caches.get_request_cache()
ct = ContentType.objects.get_for_model(instance)
key = 'services.Service-%i' % ct.pk
services = cache.get(key)
if services is None:
services = cls.objects.filter(content_type=ct, is_active=True)
cache.set(key, services)
return services
@cached_property @cached_property
def handler(self): def handler(self):
""" Accessor of this service handler instance """ """ Accessor of this service handler instance """
@ -251,5 +254,5 @@ class Service(models.Model):
if related_model._meta.model_name != 'account': if related_model._meta.model_name != 'account':
queryset = queryset.select_related('account').all() queryset = queryset.select_related('account').all()
for instance in queryset: for instance in queryset:
updates += order_model.update_orders(instance, service=self, commit=commit) updates += order_model.objects.update_by_instance(instance, service=self, commit=commit)
return updates return updates

View File

@ -22,7 +22,7 @@
metric=& metric=&
nominal_price=28.10& nominal_price=28.10&
tax=21& tax=21&
pricing_period=BILLING_PERIOD& pricing_period=NEVER&
rate_algorithm=orchestra.contrib.plans.ratings.step_price& rate_algorithm=orchestra.contrib.plans.ratings.step_price&
on_cancel=COMPENSATE& on_cancel=COMPENSATE&
payment_style=PREPAY">Mailbox</option> payment_style=PREPAY">Mailbox</option>

View File

@ -67,7 +67,7 @@ class MailboxBillingTest(BaseTestCase):
return self.resource return self.resource
def allocate_disk(self, mailbox, value): def allocate_disk(self, mailbox, value):
data, __ = ResourceData.get_or_create(mailbox, self.resource) data, __ = ResourceData.objects.get_or_create(mailbox, self.resource)
data.allocated = value data.allocated = value
data.save() data.save()

View File

@ -58,7 +58,7 @@ class BaseTrafficBillingTest(BaseTestCase):
def report_traffic(self, account, value): def report_traffic(self, account, value):
MonitorData.objects.create(monitor=FTPTrafficMonitor.get_name(), content_object=account.systemusers.get(), value=value) MonitorData.objects.create(monitor=FTPTrafficMonitor.get_name(), content_object=account.systemusers.get(), value=value)
data, __ = ResourceData.get_or_create(account, self.resource) data, __ = ResourceData.objects.get_or_create(account, self.resource)
data.update() data.update()

View File

@ -3,6 +3,7 @@ import collections
import random import random
import string import string
from io import StringIO from io import StringIO
from itertools import tee
def import_class(cls): def import_class(cls):
@ -118,3 +119,9 @@ def cmp_to_key(mycmp):
return mycmp(self.obj, other.obj) != 0 return mycmp(self.obj, other.obj) != 0
return K return K
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = tee(iterable)
next(b, None)
return zip(a, b)