|
1 | 1 | from IPy import IP
|
2 | 2 |
|
3 |
| -from django.db import models, connection |
4 |
| -from django.db.models import sql, query |
5 |
| -from django.db.models.query_utils import QueryWrapper |
6 |
| - |
7 |
| -NET_OPERATORS = connection.operators.copy() |
8 |
| - |
9 |
| -for operator in ['contains', 'startswith', 'endswith']: |
10 |
| - NET_OPERATORS[operator] = 'ILIKE %s' |
11 |
| - NET_OPERATORS['i%s' % operator] = 'ILIKE %s' |
12 |
| - |
13 |
| -NET_OPERATORS['iexact'] = NET_OPERATORS['exact'] |
14 |
| -NET_OPERATORS['regex'] = NET_OPERATORS['iregex'] |
15 |
| -NET_OPERATORS['net_contained'] = '<< %s' |
16 |
| -NET_OPERATORS['net_contained_or_equal'] = '<<= %s' |
17 |
| -NET_OPERATORS['net_contains'] = '>> %s' |
18 |
| -NET_OPERATORS['net_contains_or_equals'] = '>>= %s' |
19 |
| - |
20 |
| -NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s'] |
21 |
| - |
22 |
| - |
23 |
| -class NetQuery(sql.Query): |
24 |
| - query_terms = sql.Query.query_terms.copy() |
25 |
| - query_terms.update(NET_OPERATORS) |
26 |
| - |
27 |
| - def add_filter(self, (filter_string, value), *args, **kwargs): |
28 |
| - # IP(...) == '' fails so make sure to force to string while we can |
29 |
| - if isinstance(value, IP): |
30 |
| - value = unicode(value) |
31 |
| - return super(NetQuery, self).add_filter( |
32 |
| - (filter_string, value), *args, **kwargs) |
33 |
| - |
34 |
| - |
35 |
| -class NetWhere(sql.where.WhereNode): |
36 |
| - def make_atom(self, child, qn): |
37 |
| - table_alias, name, db_type, lookup_type, value_annot, params = child |
38 |
| - |
39 |
| - if db_type not in ['inet', 'cidr']: |
40 |
| - return super(NetWhere, self).make_atom(child, qn) |
41 |
| - |
42 |
| - if table_alias: |
43 |
| - field_sql = '%s.%s' % (qn(table_alias), qn(name)) |
44 |
| - else: |
45 |
| - field_sql = qn(name) |
46 |
| - |
47 |
| - if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS: |
48 |
| - if db_type == 'inet': |
49 |
| - field_sql = 'HOST(%s)' % field_sql |
50 |
| - else: |
51 |
| - field_sql = 'TEXT(%s)' % field_sql |
52 |
| - |
53 |
| - if isinstance(params, QueryWrapper): |
54 |
| - extra, params = params.data |
55 |
| - else: |
56 |
| - extra = '' |
57 |
| - |
58 |
| - if lookup_type in NET_OPERATORS: |
59 |
| - return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params) |
60 |
| - elif lookup_type == 'in': |
61 |
| - if not value_annot: |
62 |
| - raise sql.datastructures.EmptyResultSet |
63 |
| - if extra: |
64 |
| - return ('%s IN %s' % (field_sql, extra), params) |
65 |
| - return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params) |
66 |
| - elif lookup_type == 'range': |
67 |
| - return ('%s BETWEEN %%s and %%s' % field_sql, params) |
68 |
| - elif lookup_type == 'isnull': |
69 |
| - return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), params) |
70 |
| - |
71 |
| - raise ValueError('Invalid lookup type "%s"' % lookup_type) |
72 |
| - |
73 |
| - |
74 |
| -class NetManger(models.Manager): |
75 |
| - use_for_related_fields = True |
76 |
| - |
77 |
| - def get_query_set(self): |
78 |
| - q = NetQuery(self.model, connection, NetWhere) |
79 |
| - return query.QuerySet(self.model, q) |
| 3 | +from django.db import models |
| 4 | + |
| 5 | +from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS |
| 6 | +from netfields.forms import NetAddressFormField, MACAddressFormField |
| 7 | + |
| 8 | +class _NetAddressField(models.Field): |
| 9 | + empty_strings_allowed = False |
| 10 | + |
| 11 | + def __init__(self, *args, **kwargs): |
| 12 | + kwargs['max_length'] = self.max_length |
| 13 | + super(_NetAddressField, self).__init__(*args, **kwargs) |
| 14 | + |
| 15 | + def to_python(self, value): |
| 16 | + if not value: |
| 17 | + value = None |
| 18 | + |
| 19 | + if value is None: |
| 20 | + return value |
| 21 | + |
| 22 | + return IP(value) |
| 23 | + |
| 24 | + def get_db_prep_value(self, value): |
| 25 | + if value is None: |
| 26 | + return value |
| 27 | + |
| 28 | + return unicode(self.to_python(value)) |
| 29 | + |
| 30 | + def get_db_prep_lookup(self, lookup_type, value): |
| 31 | + if value is None: |
| 32 | + return value |
| 33 | + |
| 34 | + if (lookup_type in NET_OPERATORS and |
| 35 | + NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS): |
| 36 | + return [self.get_db_prep_value(value)] |
| 37 | + |
| 38 | + return super(_NetAddressField, self).get_db_prep_lookup( |
| 39 | + lookup_type, value) |
| 40 | + |
| 41 | + def formfield(self, **kwargs): |
| 42 | + defaults = {'form_class': NetAddressFormField} |
| 43 | + defaults.update(kwargs) |
| 44 | + return super(_NetAddressField, self).formfield(**defaults) |
| 45 | + |
| 46 | + |
| 47 | +class InetAddressField(_NetAddressField): |
| 48 | + description = "PostgreSQL INET field" |
| 49 | + max_length = 39 |
| 50 | + __metaclass__ = models.SubfieldBase |
| 51 | + |
| 52 | + def db_type(self): |
| 53 | + return 'inet' |
| 54 | + |
| 55 | + |
| 56 | +class CidrAddressField(_NetAddressField): |
| 57 | + description = "PostgreSQL CIDR field" |
| 58 | + max_length = 43 |
| 59 | + __metaclass__ = models.SubfieldBase |
| 60 | + |
| 61 | + def db_type(self): |
| 62 | + return 'cidr' |
| 63 | + |
| 64 | + |
| 65 | +class MACAddressField(models.Field): |
| 66 | + description = "PostgreSQL MACADDR field" |
| 67 | + |
| 68 | + def __init__(self, *args, **kwargs): |
| 69 | + kwargs['max_length'] = 17 |
| 70 | + super(MACAddressField, self).__init__(*args, **kwargs) |
| 71 | + |
| 72 | + def db_type(self): |
| 73 | + return 'macaddr' |
| 74 | + |
| 75 | + def formfield(self, **kwargs): |
| 76 | + defaults = {'form_class': MACAddressFormField} |
| 77 | + defaults.update(kwargs) |
| 78 | + return super(MACAddressField, self).formfield(**defaults) |
0 commit comments