# ShapeData.py
# coding=utf-8
#author:        4N
#createtime:    2022/3/15
#email:         nheweijun@sina.com


from osgeo import ogr
from osgeo.ogr import *
import math
import os
import copy

class ShapeData:
    """Thin wrapper around an ESRI Shapefile accessed through OGR.

    An instance opens an existing shapefile read-only and exposes its first
    layer. The class methods create new shapefiles from a layer, from a field
    scheme plus features, from a single point, or from a list of WKT strings.
    Instances can be used as context managers; ``close`` runs on exit.
    """

    # Shared OGR driver used both for opening and for creating shapefiles.
    driver = ogr.GetDriverByName("ESRI Shapefile")

    # Files removed by create_shp_fromwkts before re-creating a shapefile in place.
    _SIDECAR_EXTENSIONS = ("dbf", "prj", "cpg", "shp", "shx", "sbn", "sbx")

    def __init__(self, path):
        """Open the shapefile at *path* read-only and bind its first layer.

        :param path: path of the shapefile to open.
        :raises Exception: if OGR cannot open the data source.
        """
        self.ds: DataSource = self.driver.Open(path, 0)  # 0 = read-only
        if not self.ds:
            # Message means "failed to open data!".
            raise Exception("打开数据失败!")
        self.layer: Layer = self.ds.GetLayer(0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

    def get_polygons(self):
        """Return every polygon in the layer as an independent geometry.

        Polygon features contribute one geometry each; multipolygon features
        are exploded into their member polygons. Features with no geometry and
        features of any other geometry type are skipped. 2D, Z, M and ZM
        variants are all recognised, because the type code is flattened first.

        :return: list of cloned ``ogr.Geometry`` polygons owned by the caller.
        """
        polygons = []
        for feature in self.layer:
            geom: Geometry = feature.GetGeometryRef()
            if geom is None:
                # Records with a null geometry are skipped instead of crashing.
                continue
            flat_type = ogr.GT_Flatten(geom.GetGeometryType())
            if flat_type == ogr.wkbPolygon:
                # The reference is owned by the feature; clone so it outlives it.
                polygons.append(geom.Clone())
            elif flat_type == ogr.wkbMultiPolygon:
                for i in range(geom.GetGeometryCount()):
                    # Parts are owned by their parent; clone each one.
                    polygons.append(geom.GetGeometryRef(i).Clone())
        return polygons

    def close(self):
        """Release the underlying data source. Calling it again is a no-op."""
        if self.ds is not None:
            self.ds.Destroy()
            self.ds = None
            # The layer came from the destroyed data source; drop the reference.
            self.layer = None

    @classmethod
    def _create_data_source(cls, path):
        """Create a new shapefile data source at *path*, raising if OGR returns None."""
        data_source: DataSource = cls.driver.CreateDataSource(path)
        if data_source is None:
            # Message means "failed to create data".
            raise Exception("创建数据失败: " + str(path))
        return data_source

    @classmethod
    def create_by_layer(cls, path, layer: Layer):
        """Write a copy of *layer* to a new shapefile at *path*.

        The copied layer keeps the name returned by ``layer.GetName()``.

        :raises Exception: if the data source cannot be created.
        """
        data_source = cls._create_data_source(path)
        try:
            data_source.CopyLayer(layer, layer.GetName())
        finally:
            # Always close so the output is written even if copying raised.
            data_source.Destroy()

    @classmethod
    def create_by_scheme(cls, path, name, sr, geo_type, scheme, features):
        """Create a shapefile with an explicit schema and write *features* into it.

        :param path: output shapefile path.
        :param name: layer name.
        :param sr: spatial reference passed to ``CreateLayer`` (may be None).
        :param geo_type: OGR geometry type code for the layer.
        :param scheme: field definitions passed to ``Layer.CreateFields``;
            skipped when empty or None.
        :param features: iterable of ``ogr.Feature`` to write; None writes none.
        :raises Exception: if the data source cannot be created.
        """
        data_source = cls._create_data_source(path)
        try:
            layer: Layer = data_source.CreateLayer(name, sr, geo_type)
            if scheme:
                layer.CreateFields(scheme)
            for feature in features or ():
                layer.CreateFeature(feature)
        finally:
            data_source.Destroy()

    @classmethod
    def create_point(cls, path, name, point):
        """Create a point shapefile holding a single feature with geometry *point*.

        The layer has no spatial reference and no attribute fields.

        :raises Exception: if the data source cannot be created.
        """
        data_source = cls._create_data_source(path)
        try:
            layer: Layer = data_source.CreateLayer(name, None, ogr.wkbPoint)
            feat_new = ogr.Feature(layer.GetLayerDefn())
            feat_new.SetGeometry(point)
            layer.CreateFeature(feat_new)
        finally:
            data_source.Destroy()

    @classmethod
    def create_shp_fromwkts(cls, path, name, wkts):
        """Create (or overwrite) a shapefile at *path* with one feature per WKT.

        The layer geometry type is taken from the first geometry. An empty
        *wkts* produces an empty layer of type ``wkbUnknown``. All WKT is parsed
        before any existing file is touched, so bad input never deletes data.

        :param path: output ``.shp`` path; an existing shapefile there is replaced.
        :param name: layer name.
        :param wkts: iterable of WKT strings.
        :raises ValueError: if a WKT string cannot be parsed.
        :raises Exception: if the data source cannot be created.
        """
        geoms = []
        for wkt in wkts:
            geom: Geometry = ogr.CreateGeometryFromWkt(wkt)
            if geom is None:
                # Message means "invalid WKT".
                raise ValueError("无效的WKT: " + str(wkt))
            geoms.append(geom)
        geo_type = geoms[0].GetGeometryType() if geoms else ogr.wkbUnknown

        if os.path.exists(path):
            # Remove the old shapefile and its sidecar files so the new data
            # source can be created in their place.
            pre_name = os.path.splitext(path)[0]
            for ext in cls._SIDECAR_EXTENSIONS:
                try:
                    os.remove(pre_name + "." + ext)
                except FileNotFoundError:
                    # Not every sidecar file exists for every shapefile.
                    pass

        data_source = cls._create_data_source(path)
        try:
            layer: Layer = data_source.CreateLayer(name, None, geo_type)
            for geom in geoms:
                feat_new = ogr.Feature(layer.GetLayerDefn())
                feat_new.SetGeometry(geom)
                layer.CreateFeature(feat_new)
        finally:
            data_source.Destroy()