1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
from django.contrib.gis.db import models
from django.contrib.gis.db.backends.base.adapter import WKTAdapter
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.geos.geometry import GEOSGeometryBase
from django.contrib.gis.geos.prototypes.io import wkb_r
from django.contrib.gis.measure import Distance
from django.db.backends.mysql.operations import DatabaseOperations
from django.utils.functional import cached_property
class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
    """Spatial database operations for the MySQL / MariaDB GIS backend.

    Combines the generic spatial operations interface with the regular MySQL
    ``DatabaseOperations`` and describes which spatial functions, lookup
    operators and aggregates are available, depending on whether the
    connection reports MariaDB (``connection.mysql_is_mariadb``) and on the
    server version (``connection.mysql_version``).
    """

    name = "mysql"
    # Prefix prepended to the spatial function names built by this class.
    geom_func_prefix = "ST_"

    # Adapter class used for geometry values (WKT based).
    Adapter = WKTAdapter

    @cached_property
    def mariadb(self):
        """Return ``True`` when the connection reports a MariaDB server."""
        return self.connection.mysql_is_mariadb

    @cached_property
    def mysql(self):
        """Return ``True`` when the connection does not report MariaDB."""
        return not self.connection.mysql_is_mariadb

    @cached_property
    def select(self):
        """Return the ``<prefix>AsBinary(%s)`` template string."""
        return self.geom_func_prefix + "AsBinary(%s)"

    @cached_property
    def from_text(self):
        """Return the name of the ``<prefix>GeomFromText`` function."""
        return self.geom_func_prefix + "GeomFromText"

    @cached_property
    def collect(self):
        """Return the ``<prefix>Collect`` function name, or ``None``.

        ``None`` is returned when the connection's features do not advertise
        support for the collect aggregate (``supports_collect_aggr``).
        """
        if self.connection.features.supports_collect_aggr:
            return self.geom_func_prefix + "Collect"
        return None

    @cached_property
    def gis_operators(self):
        """Return the mapping of spatial lookup names to ``SpatialOperator``s.

        The base set is shared; ``relate`` is only registered for MariaDB,
        while ``covers`` and ``coveredby`` are only registered otherwise.
        """
        operators = {
            "bbcontains": SpatialOperator(
                func="MBRContains"
            ),  # For consistency w/PostGIS API
            "bboverlaps": SpatialOperator(func="MBROverlaps"),  # ...
            "contained": SpatialOperator(func="MBRWithin"),  # ...
            "contains": SpatialOperator(func="ST_Contains"),
            "crosses": SpatialOperator(func="ST_Crosses"),
            "disjoint": SpatialOperator(func="ST_Disjoint"),
            "equals": SpatialOperator(func="ST_Equals"),
            "exact": SpatialOperator(func="ST_Equals"),
            "intersects": SpatialOperator(func="ST_Intersects"),
            "overlaps": SpatialOperator(func="ST_Overlaps"),
            "same_as": SpatialOperator(func="ST_Equals"),
            "touches": SpatialOperator(func="ST_Touches"),
            "within": SpatialOperator(func="ST_Within"),
        }
        if self.connection.mysql_is_mariadb:
            operators["relate"] = SpatialOperator(func="ST_Relate")
        else:
            operators["covers"] = SpatialOperator(func="MBRCovers")
            operators["coveredby"] = SpatialOperator(func="MBRCoveredBy")
        return operators

    @cached_property
    def disallowed_aggregates(self):
        """Return the tuple of spatial aggregate classes that are disallowed.

        ``models.Collect`` is additionally disallowed on MariaDB and on
        MySQL servers older than 8.0.24.
        """
        disallowed_aggregates = [
            models.Extent,
            models.Extent3D,
            models.MakeLine,
            models.Union,
        ]
        is_mariadb = self.connection.mysql_is_mariadb
        if is_mariadb or self.connection.mysql_version < (8, 0, 24):
            disallowed_aggregates.insert(0, models.Collect)
        return tuple(disallowed_aggregates)

    # Overrides for spatial function names that differ on this backend.
    function_names = {
        "FromWKB": "ST_GeomFromWKB",
        "FromWKT": "ST_GeomFromText",
    }

    @cached_property
    def unsupported_functions(self):
        """Return the set of spatial function names that are unsupported.

        On MariaDB, ``PointOnSurface`` is removed from the set and
        ``GeoHash`` and ``IsValid`` are added to it.
        """
        unsupported = {
            "AsGML",
            "AsKML",
            "AsSVG",
            "Azimuth",
            "BoundingCircle",
            "ClosestPoint",
            "ForcePolygonCW",
            "GeometryDistance",
            "IsEmpty",
            "LineLocatePoint",
            "MakeValid",
            "MemSize",
            "Perimeter",
            "PointOnSurface",
            "Reverse",
            "Scale",
            "SnapToGrid",
            "Transform",
            "Translate",
        }
        if self.connection.mysql_is_mariadb:
            unsupported.remove("PointOnSurface")
            unsupported.update({"GeoHash", "IsValid"})
        return unsupported

    def geo_db_type(self, f):
        """Return the database column type for field ``f`` (its ``geom_type``)."""
        return f.geom_type

    def get_distance(self, f, value, lookup_type):
        """Return the distance parameter list for a distance lookup.

        Only the first element of ``value`` is used. A ``Distance`` instance
        is converted to the units of field ``f``; any other value is passed
        through unchanged. ``lookup_type`` is accepted but not used here.

        Raises:
            ValueError: if ``value[0]`` is a ``Distance`` and ``f`` is
                geodetic for this connection.
        """
        value = value[0]
        if isinstance(value, Distance):
            if f.geodetic(self.connection):
                raise ValueError(
                    "Only numeric values of degree units are allowed on "
                    "geodetic distance queries."
                )
            # Read the Distance attribute matching the field's unit name.
            dist_param = getattr(
                value, Distance.unit_attname(f.units_name(self.connection))
            )
        else:
            dist_param = value
        return [dist_param]

    def get_geometry_converter(self, expression):
        """Return a converter turning a binary value into a geometry object.

        The returned callable reads the value as WKB, wraps the result in
        ``GEOSGeometryBase`` using the output field's ``geom_class`` and
        assigns the output field's SRID when it is set. ``None`` values are
        passed through as ``None``.
        """
        read = wkb_r().read
        srid = expression.output_field.srid
        # An SRID of -1 is treated the same as no SRID at all.
        if srid == -1:
            srid = None
        geom_class = expression.output_field.geom_class

        def converter(value, expression, connection):
            if value is None:
                return None
            geom = GEOSGeometryBase(read(memoryview(value)), geom_class)
            if srid:
                geom.srid = srid
            return geom

        return converter

    def spatial_aggregate_name(self, agg_name):
        """Return the attribute of ``self`` named by ``agg_name`` lowercased."""
        return getattr(self, agg_name.lower())
|