Automatically use a function at insert or select

Sometimes the application wants to apply a function in an insert or in a select. For example, the application might need the geometry in lat/lon coordinates while it is stored in a projected coordinate system in the DB. To avoid having to tweak every query with an ST_Transform() call, it is possible to define a TypeDecorator that applies the function automatically.

 11 from sqlalchemy import Column
 12 from sqlalchemy import Integer
 13 from sqlalchemy import MetaData
 14 from sqlalchemy import func
 15 from sqlalchemy import text
 16 from sqlalchemy.ext.declarative import declarative_base
 17 from sqlalchemy.types import TypeDecorator
 18
 19 from geoalchemy2 import Geometry
 20 from geoalchemy2 import shape
 21
 22 # Tests imports
 23 from tests import test_only_with_dialects
 24
 25 metadata = MetaData()
 26
 27 Base = declarative_base(metadata=metadata)
 28
 29
 30 class TransformedGeometry(TypeDecorator):
 31     """This class is used to insert a ST_Transform() in each insert or select."""
 32     impl = Geometry
 33
 34     cache_ok = True
 35
 36     def __init__(self, db_srid, app_srid, **kwargs):
 37         kwargs["srid"] = db_srid
 38         super().__init__(**kwargs)
 39         self.app_srid = app_srid
 40         self.db_srid = db_srid
 41
 42     def column_expression(self, col):
 43         """The column_expression() method is overridden to set the correct type.
 44
 45         This is needed so that the returned element will also be decorated. In this case we don't
 46         want to transform it again afterwards so we set the same SRID to both the ``db_srid`` and
 47         ``app_srid`` arguments.
 48         Without this the SRID of the WKBElement would be wrong.
 49         """
 50         return getattr(func, self.impl.as_binary)(
 51             func.ST_Transform(col, self.app_srid),
 52             type_=self.__class__(db_srid=self.app_srid, app_srid=self.app_srid)
 53         )
 54
 55     def bind_expression(self, bindvalue):
 56         return func.ST_Transform(
 57             self.impl.bind_expression(bindvalue), self.db_srid,
 58             type_=self,
 59         )
 60
 61
 62 class ThreeDGeometry(TypeDecorator):
 63     """This class is used to insert a ST_Force3D() in each insert."""
 64     impl = Geometry
 65
 66     cache_ok = True
 67
 68     def column_expression(self, col):
 69         """The column_expression() method is overridden to set the correct type.
 70
 71         This is not needed in this example but it is needed if one wants to override other methods
 72         of the TypeDecorator class, like ``process_result_value()`` for example.
 73         """
 74         return getattr(func, self.impl.as_binary)(col, type_=self)
 75
 76     def bind_expression(self, bindvalue):
 77         return func.ST_Force3D(
 78             self.impl.bind_expression(bindvalue),
 79             type=self,
 80         )
 81
 82
 83 class Point(Base):
 84     __tablename__ = "point"
 85     id = Column(Integer, primary_key=True)
 86     raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
 87     geom = Column(
 88         TransformedGeometry(
 89             db_srid=2154, app_srid=4326, geometry_type="POINT"))
 90     three_d_geom = Column(
 91         ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 92
 93
 94 def check_wkb(wkb, x, y):
 95     pt = shape.to_shape(wkb)
 96     assert round(pt.x, 5) == x
 97     assert round(pt.y, 5) == y
 98
 99
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Tests for the decorated geometry columns of ``Point``."""

    def _create_one_point(self, session, conn):
        """Recreate the schema, insert a single ``Point`` and return its id."""
        # Start from a freshly created schema
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        # The same EWKT value is assigned to all three geometry columns
        ewkt = "SRID=4326;POINT(5 45)"
        point = Point()
        point.raw_geom = ewkt
        point.geom = ewkt
        point.three_d_geom = ewkt  # 2D geometry stored in the 3D column

        # Flush the insert, then expire the instance so its attributes are
        # loaded again on the next access
        session.add(point)
        session.flush()
        session.expire(point)

        return point.id

    def test_transform(self, session, conn):
        """Check the SRID and coordinates of ``geom`` via the ORM and via raw SQL."""
        self._create_one_point(session, conn)

        # Load the point through the ORM and check both geometry columns
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # Read the stored value with a raw query to check what is in the DB
        raw_sql = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        raw_row = session.execute(raw_sql).fetchone()
        assert raw_row.id == 1
        assert raw_row.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"

        # Fetch the entity, its raw geometry and an explicitly transformed copy
        explicit_transform = func.ST_Transform(Point.raw_geom, 2154).label("trans")
        row = session.query(Point, Point.raw_geom, explicit_transform).one()
        entity = row[0]
        raw_geom = row[1]
        transformed = row[2]

        assert entity.id == 1

        assert entity.geom.srid == 4326
        check_wkb(entity.geom, 5, 45)

        assert entity.raw_geom.srid == 4326
        check_wkb(entity.raw_geom, 5, 45)

        assert raw_geom.srid == 4326
        check_wkb(raw_geom, 5, 45)

        assert transformed.srid == 2154
        check_wkb(transformed, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        """Check the value read back from the ``three_d_geom`` column."""
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326

        # Expected lowercase hex description of the stored element
        # NOTE(review): appears to encode POINT Z (5 45 0) with SRID 4326 - confirm
        expected_desc = (
            '01010000a0e6100000000000000000144000000000008046400000000000000000')
        assert loaded.three_d_geom.desc.lower() == expected_desc

Gallery generated by Sphinx-Gallery