|
1 | 1 | import json |
2 | 2 |
|
3 | | -from django.contrib.postgres.forms import SimpleArrayField |
4 | 3 | from django.contrib.postgres.validators import ArrayMaxLengthValidator |
5 | 4 | from django.core import checks, exceptions |
6 | | -from django.db.models import DecimalField, Field, Func, Transform, Value |
| 5 | +from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value |
7 | 6 | from django.db.models.fields.mixins import CheckFieldDefaultMixin |
| 7 | +from django.db.models.lookups import In |
8 | 8 | from django.utils.translation import gettext_lazy as _ |
9 | 9 |
|
10 | | -__all__ = ["ArrayField"] |
| 10 | +from django_mongodb.forms import SimpleArrayField |
11 | 11 |
|
| 12 | +from ..utils import prefix_validation_error |
12 | 13 |
|
13 | | -from django.core.exceptions import ValidationError |
14 | | -from django.utils.functional import SimpleLazyObject |
15 | | -from django.utils.text import format_lazy |
| 14 | +__all__ = ["ArrayField"] |
16 | 15 |
|
17 | 16 |
|
18 | 17 | class AttributeSetter: |
19 | 18 | def __init__(self, name, value): |
20 | 19 | setattr(self, name, value) |
21 | 20 |
|
22 | 21 |
|
23 | | -def prefix_validation_error(error, prefix, code, params): |
24 | | - """ |
25 | | - Prefix a validation error message while maintaining the existing |
26 | | - validation data structure. |
27 | | - """ |
28 | | - if error.error_list == [error]: |
29 | | - error_params = error.params or {} |
30 | | - return ValidationError( |
31 | | - # We can't simply concatenate messages since they might require |
32 | | - # their associated parameters to be expressed correctly which |
33 | | - # is not something `format_lazy` does. For example, proxied |
34 | | - # ngettext calls require a count parameter and are converted |
35 | | - # to an empty string if they are missing it. |
36 | | - message=format_lazy( |
37 | | - "{} {}", |
38 | | - SimpleLazyObject(lambda: prefix % params), |
39 | | - SimpleLazyObject(lambda: error.message % error_params), |
40 | | - ), |
41 | | - code=code, |
42 | | - params={**error_params, **params}, |
43 | | - ) |
44 | | - return ValidationError( |
45 | | - [prefix_validation_error(e, prefix, code, params) for e in error.error_list] |
46 | | - ) |
47 | | - |
48 | | - |
49 | 22 | class ArrayField(CheckFieldDefaultMixin, Field): |
50 | 23 | empty_strings_allowed = False |
51 | 24 | default_error_messages = { |
@@ -293,55 +266,44 @@ def _rhs_not_none_values(self, rhs): |
293 | 266 | yield True |
294 | 267 |
|
295 | 268 |
|
296 | | -# @ArrayField.register_lookup |
297 | | -# class ArrayContains(ArrayRHSMixin, lookups.DataContains): |
298 | | -# pass |
299 | | - |
300 | | - |
301 | | -# @ArrayField.register_lookup |
302 | | -# class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy): |
303 | | -# pass |
304 | | - |
305 | | - |
306 | 269 | # @ArrayField.register_lookup |
307 | 270 | # class ArrayExact(ArrayRHSMixin, Exact): |
308 | | -# pass |
| 271 | +# pass |
309 | 272 |
|
310 | 273 |
|
311 | | -# @ArrayField.register_lookup |
312 | | -# class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): |
313 | | -# pass |
| 274 | +@ArrayField.register_lookup |
| 275 | +class ArrayLenTransform(Transform): |
| 276 | + lookup_name = "len" |
| 277 | + output_field = IntegerField() |
314 | 278 |
|
315 | | - |
316 | | -# @ArrayField.register_lookup |
317 | | -# class ArrayLenTransform(Transform): |
318 | | -# lookup_name = "len" |
319 | | -# output_field = IntegerField() |
320 | | - |
321 | | -# def as_sql(self, compiler, connection): |
322 | | -# lhs, params = compiler.compile(self.lhs) |
323 | | -# # Distinguish NULL and empty arrays |
324 | | -# return ( |
325 | | -# "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " |
326 | | -# "coalesce(array_length(%(lhs)s, 1), 0) END" |
327 | | -# ) % {"lhs": lhs}, params * 2 |
| 279 | + def as_sql(self, compiler, connection): |
| 280 | + lhs, params = compiler.compile(self.lhs) |
| 281 | + # Distinguish NULL and empty arrays |
| 282 | + return ( |
| 283 | + ( |
| 284 | + "" # "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " |
| 285 | + # "coalesce(array_length(%(lhs)s, 1), 0) END" |
| 286 | + ) |
| 287 | + % {}, |
| 288 | + params * 2, |
| 289 | + ) |
328 | 290 |
|
329 | 291 |
|
330 | | -# @ArrayField.register_lookup |
331 | | -# class ArrayInLookup(In): |
332 | | -# def get_prep_lookup(self): |
333 | | -# values = super().get_prep_lookup() |
334 | | -# if hasattr(values, "resolve_expression"): |
335 | | -# return values |
336 | | -# # In.process_rhs() expects values to be hashable, so convert lists |
337 | | -# # to tuples. |
338 | | -# prepared_values = [] |
339 | | -# for value in values: |
340 | | -# if hasattr(value, "resolve_expression"): |
341 | | -# prepared_values.append(value) |
342 | | -# else: |
343 | | -# prepared_values.append(tuple(value)) |
344 | | -# return prepared_values |
| 292 | +@ArrayField.register_lookup |
| 293 | +class ArrayInLookup(In): |
| 294 | + def get_prep_lookup(self): |
| 295 | + values = super().get_prep_lookup() |
| 296 | + if hasattr(values, "resolve_expression"): |
| 297 | + return values |
| 298 | + # In.process_rhs() expects values to be hashable, so convert lists |
| 299 | + # to tuples. |
| 300 | + prepared_values = [] |
| 301 | + for value in values: |
| 302 | + if hasattr(value, "resolve_expression"): |
| 303 | + prepared_values.append(value) |
| 304 | + else: |
| 305 | + prepared_values.append(tuple(value)) |
| 306 | + return prepared_values |
345 | 307 |
|
346 | 308 |
|
347 | 309 | class IndexTransform(Transform): |
@@ -388,6 +350,5 @@ def __init__(self, start, end): |
388 | 350 | self.start = start |
389 | 351 | self.end = end |
390 | 352 |
|
391 | | - |
392 | | -# def __call__(self, *args, **kwargs): |
393 | | -# return SliceTransform(self.start, self.end, *args, **kwargs) |
| 353 | + def __call__(self, *args, **kwargs): |
| 354 | + return SliceTransform(self.start, self.end, *args, **kwargs) |
0 commit comments