Skip to content

Commit

Permalink
feat: add method in VectorDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
habibutsu committed Jun 7, 2024
1 parent edf28a3 commit a7c00b5
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
68 changes: 56 additions & 12 deletions gdal_boots/gdal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from enum import Enum
from numbers import Number
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
from uuid import uuid4

import affine
Expand Down Expand Up @@ -1005,6 +1005,17 @@ def __setitem__(self, idx, feature: Feature):
INV_FIELD_TYPES = {v: k for k, v in FIELD_TYPES.items() if k is not dict}


class GeometryType(Enum):
Point = ogr.wkbPoint
LineString = ogr.wkbLineString
Polygon = ogr.wkbPolygon
MultiPoint = ogr.wkbMultiPoint
MultiLineString = ogr.wkbMultiLineString
MultiPolygon = ogr.wkbMultiPolygon
GeometryCollection = ogr.wkbGeometryCollection
Unknown = ogr.wkbUnknown


class Layer:
__slots__ = ("ref_ds", "ref_layer")

Expand All @@ -1028,6 +1039,17 @@ def name(self):
def epsg(self) -> int:
return int(self.ref_layer.GetSpatialRef().GetAuthorityCode(None))

@property
def srs(self) -> osr.SpatialReference:
return self.ref_layer.GetSpatialRef()

@property
def geometry_type(self) -> GeometryType:
"""
returns int (ogr.wkbPoint, ogr.wkbPolygon, ...)
"""
return GeometryType(self.ref_layer.GetGeomType())

def set_epsg(self, epsg: int):
logger.warning("this is not legal way to change epsg")
self.ref_layer.GetSpatialRef().ImportFromEPSG(epsg)
Expand Down Expand Up @@ -1152,15 +1174,6 @@ def size(self):
class VectorDataset:
# https://livebook.manning.com/book/geoprocessing-with-python/chapter-3/1

class GeometryType(Enum):
Point = ogr.wkbPoint
LineString = ogr.wkbLineString
Polygon = ogr.wkbPolygon
MultiPoint = ogr.wkbMultiPoint
MultiLineString = ogr.wkbMultiLineString
MultiPolygon = ogr.wkbMultiPolygon
GeometryCollection = ogr.wkbGeometryCollection

def __init__(self, ds):
self.ds: ogr.DataSource | gdal.Dataset = ds
# self.layers = None
Expand Down Expand Up @@ -1222,6 +1235,7 @@ def to_file(self, filename: str, options: DriverOptions, overwrite=True) -> None

# # field = ogr.FieldDefn('field', ogr.OFTInteger)
# # layer.CreateField(field)

driver: ogr.Driver = ogr.GetDriverByName(options.driver_name)
try:
out_ds: ogr.DataSource = driver.CreateDataSource(filename)
Expand All @@ -1236,7 +1250,7 @@ def to_file(self, filename: str, options: DriverOptions, overwrite=True) -> None
# (for example empty) datasource will not created
driver.DeleteDataSource(filename)
out_ds: ogr.DataSource = driver.CreateDataSource(filename)
else:
elif out_ds is None:
raise RuntimeError(gdal.GetLastErrorMsg())

assert out_ds is not None
Expand Down Expand Up @@ -1277,7 +1291,37 @@ def simplify(self, tolerance=5):
feature.simplify(tolerance)

def union(self, other):
pass
raise NotImplementedError()

def to_epsg(self, out_epsg: int):
output_srs = osr.SpatialReference()
output_srs.ImportFromEPSG(out_epsg)

result_vds = VectorDataset.create()
for layer in self.layers:
input_srs = layer.srs

transform = osr.CoordinateTransformation(input_srs, output_srs)
output_layer = result_vds.add_layer(name=layer.name, geom_type=layer.geometry_type, epsg=out_epsg)
layer_defn: ogr.FeatureDefn = layer.ref_layer.GetLayerDefn()

for i in range(layer_defn.GetFieldCount()):
field_defn = layer_defn.GetFieldDefn(i)
output_layer.ref_layer.CreateField(field_defn)

output_layer_defn = output_layer.ref_layer.GetLayerDefn()
# Reproject and copy the features from the input layer to the output layer
for feature in cast(list[ogr.Feature], layer.ref_layer):
geom: ogr.Geometry = feature.GetGeometryRef()
geom.Transform(transform)

output_feature = ogr.Feature(output_layer_defn)
output_feature.SetGeometry(geom)

for i in range(layer_defn.GetFieldCount()):
output_feature.SetField(layer_defn.GetFieldDefn(i).GetNameRef(), feature.GetField(i))
output_layer.ref_layer.CreateFeature(output_feature)
return result_vds

def __del__(self):
if type(self.ds) is ogr.DataSource:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ def test_open_file(minsk_boundary_geojson):
assert ds.layers[0].features[0]["name:en"] == "Minsk"


def test_to_epsg_gpkg(minsk_boundary_gpkg):
vds = VectorDataset.open(minsk_boundary_gpkg)
vds_3857 = vds.to_epsg(3857)
with tempfile.NamedTemporaryFile(suffix=".gpkg") as fd:
vds_3857.to_file(fd.name, options.GPKG())


def test_to_epsg_geojson(minsk_boundary_geojson):
vds = VectorDataset.open(minsk_boundary_geojson)
vds_3857 = vds.to_epsg(3857)
with tempfile.NamedTemporaryFile(suffix=".geojson") as fd:
vds_3857.to_file(fd.name, options.GPKG())


def test_read_from_bytes(minsk_boundary_gpkg):
with open(minsk_boundary_gpkg, "rb") as fd:
data = fd.read()
Expand Down

0 comments on commit a7c00b5

Please sign in to comment.