PGUtil.py 5.0 KB
# coding=utf-8
#author:        4N
#createtime:    2021/5/24
#email:         nheweijun@sina.com
from urllib.parse import unquote

from osgeo import ogr
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker,Session
class PGUtil:
    """Helper classmethods for PostgreSQL/PostGIS access through GDAL/OGR and SQLAlchemy."""

    @staticmethod
    def _quote_conninfo_value(value):
        """Return *value* single-quoted for a libpq ``key=value`` connection string.

        Backslashes and single quotes are backslash-escaped (libpq quoting rules),
        so values containing spaces or quotes, and empty values, survive intact.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return "'" + escaped + "'"

    @staticmethod
    def _quote_ident(name):
        """Return *name* as a double-quoted SQL identifier (embedded ``"`` doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _sql_literal(value):
        """Return *value* as a single-quoted SQL string literal (embedded ``'`` doubled).

        Only used where bound parameters are unavailable (OGR ``ExecuteSQL``).
        NOTE(review): assumes ``standard_conforming_strings = on`` (PostgreSQL default).
        """
        return "'" + str(value).replace("'", "''") + "'"

    @classmethod
    def open_pg_data_source(cls, iswrite, uri):
        """Open a PostGIS database as an OGR data source.

        :param iswrite: passed to ``driver.Open`` as the update flag.
        :param uri: SQLAlchemy-style URL, e.g. ``postgresql://user:pwd@host:5432/db``.
        :return: the data source returned by the OGR PostgreSQL driver.
        :raises Exception: if the PostgreSQL driver is missing or opening fails.
        """
        conn_parts = cls.get_info_from_sqlachemy_uri(uri)
        # Quote every value so credentials with spaces/quotes cannot break the string.
        fn = "PG: user={} password={} host={} port={} dbname={} ".format(
            *(cls._quote_conninfo_value(part) for part in conn_parts)
        )
        driver = ogr.GetDriverByName("PostgreSQL")
        if driver is None:
            raise Exception("打开PostgreSQL驱动失败,可能是当前GDAL未支持PostgreSQL驱动!")
        ds = driver.Open(fn, iswrite)
        if ds is None:
            raise Exception("打开数据源失败!")
        return ds

    @classmethod
    def get_info_from_sqlachemy_uri(cls, uri):
        """Split a SQLAlchemy database URL into its connection parts.

        Accepts ``dialect[+driver]://user:password@host[:port]/database[?query]``.
        Credentials are separated from the host at the LAST ``@`` (passwords may
        contain ``@``), and user/password at the FIRST ``:`` (passwords may contain
        ``:``). Percent-escapes in user and password are decoded, as SQLAlchemy
        does, so the OGR connection uses the same credentials as ``create_engine``.
        Any ``?query`` suffix is dropped.

        :param uri: the database URL.
        :return: tuple of strings ``(user, password, host, port, database)``;
                 ``port`` is ``"5432"`` when the URL omits it.
        :raises ValueError: if *uri* has no ``://`` separator.
        """
        _, sep, rest = uri.partition("://")
        if not sep:
            # Deliberately not echoing the URL: it may contain a password.
            raise ValueError("Invalid database URL: missing '://'")
        userinfo, _, location = rest.rpartition("@")
        user, _, password = userinfo.partition(":")
        host_port, _, db_part = location.partition("/")
        host, _, port = host_port.partition(":")
        database = db_part.partition("?")[0]
        return unquote(user), unquote(password), host, port or "5432", database

    @classmethod
    def get_db_session(cls, db_url, autocommit=False) -> "Session":
        """Create an engine for *db_url* and return a new Session bound to it.

        NOTE(review): every call creates a new Engine (and its connection pool);
        callers invoking this repeatedly may want to reuse one engine.
        """
        engine = create_engine(db_url)
        system_session: Session = sessionmaker(bind=engine, autocommit=autocommit)()
        return system_session

    @classmethod
    def get_geo_column(cls, table_name, db_session):
        """Return the name of a ``geometry``-typed column of *table_name*, or None.

        The table is matched by ``pg_class.relname`` only (no schema filter). If
        several geometry columns exist, the last one returned by the query wins.
        """
        geom_col_sql = text('''
        SELECT a.attname AS field,t.typname AS type
        FROM
        pg_class c,
        pg_attribute a,
        pg_type t
        WHERE
        c.relname = :table_name
        and a.attnum > 0
        and a.attrelid = c.oid
        and a.atttypid = t.oid
        ''')
        geom_col = None
        # Bound parameter instead of string formatting: no SQL injection via table_name.
        for row in db_session.execute(geom_col_sql, {"table_name": table_name}):
            if row[1] == "geometry":
                geom_col = row[0]
        return geom_col

    @classmethod
    def get_pkey(cls, table_name, db_session):
        """Return the primary-key column name of *table_name*, or None.

        Only the first column of the primary-key constraint (``conkey[1]``) is
        considered, so composite keys yield just their first column.
        """
        pkey_sql = text('''
        select pg_attribute.attname as colname from
        pg_constraint  inner join pg_class
        on pg_constraint.conrelid = pg_class.oid
        inner join pg_attribute on pg_attribute.attrelid = pg_class.oid
        and  pg_attribute.attnum = pg_constraint.conkey[1]
        inner join pg_type on pg_type.oid = pg_attribute.atttypid
        where pg_class.relname = :table_name
        and pg_constraint.contype='p'
        ''')
        pkey = None
        for row in db_session.execute(pkey_sql, {"table_name": table_name}):
            pkey = row[0]
        return pkey

    @classmethod
    def get_table_count(cls, table_name, db_session):
        """Return the row count of ``public.<table_name>``.

        The planner estimate (``pg_class.reltuples``) is returned when it is at
        least 1,000,000; otherwise an exact ``count(*)`` is run instead.
        """
        qualified_name = "public." + cls._quote_ident(table_name)
        count = db_session.execute(
            text("SELECT reltuples::bigint AS ec FROM pg_class WHERE oid = CAST(:rel AS regclass)"),
            {"rel": qualified_name},
        ).fetchone()[0]
        if count < 1000000:
            # Identifiers cannot be bound, so the table name is quoted instead;
            # it is the same public.<table> the estimate above was read for.
            count = db_session.execute(text("select count(*) from " + qualified_name)).fetchone()[0]
        return count

    @classmethod
    def check_table_privilege(cls, table_name, pri_type, user, pg_ds):
        """Check via *pg_ds* whether *user* holds *pri_type* on *table_name*.

        Looks up ``information_schema.table_privileges`` through OGR ``ExecuteSQL``.

        :param table_name: table name to check.
        :param pri_type: privilege type, e.g. ``SELECT``.
        :param user: grantee name.
        :param pg_ds: an open OGR PostgreSQL data source.
        :return: True if a matching privilege row exists; False otherwise,
                 including when ``ExecuteSQL`` returns no result layer.
        """
        # OGR ExecuteSQL has no bound parameters, so literals are escaped manually.
        sql = ("select * from information_schema.table_privileges "
               "where grantee={} and table_name={} and privilege_type={} ").format(
            cls._sql_literal(user), cls._sql_literal(table_name), cls._sql_literal(pri_type))
        result = pg_ds.ExecuteSQL(sql)
        if result is None:
            return False
        try:
            return result.GetNextFeature() is not None
        finally:
            # Result layers from ExecuteSQL must be released back to the data source.
            pg_ds.ReleaseResultSet(result)

    @classmethod
    def get_srid(cls, pg_ds, table_name):
        """Return the SRID of the first geometry in ``public.<table_name>``, or None.

        None is returned when the layer does not exist, has no geometry column,
        the query yields no row, the SRID is falsy (e.g. 0), or the query fails.
        """
        layer = pg_ds.GetLayerByName(table_name)
        # ``is None`` rather than truthiness: OGR layers may define __len__ (feature count).
        if layer is None:
            return None
        geom_column = layer.GetGeometryColumn()
        if not geom_column:
            return None
        srid_sql = "select st_srid({}) from public.{} limit 1".format(
            cls._quote_ident(geom_column), cls._quote_ident(layer.GetName()))
        srid_layer = pg_ds.ExecuteSQL(srid_sql)
        if srid_layer is None:
            return None
        try:
            srid_feature = srid_layer.GetNextFeature()
            if srid_feature is not None and srid_feature.GetField(0):
                return int(srid_feature.GetField(0))
            return None
        finally:
            pg_ds.ReleaseResultSet(srid_layer)

    @classmethod
    def check_database_privilege(cls, table_name, pri_type, user, session):
        """Not implemented yet; always returns None."""
        pass

if __name__ == '__main__':
    import sys

    # Manual smoke test: list the databases visible to the given connection URL.
    # NOTE(review): the fallback URL embeds a password in source; prefer passing
    # the URL as the first command-line argument.
    db_url = sys.argv[1] if len(sys.argv) > 1 else "postgresql://postgres:chinadci@172.26.60.100:5432/template1"
    session: Session = PGUtil.get_db_session(db_url)
    try:
        for row in session.execute(text("SELECT datname FROM pg_database;")):
            print(row)
    finally:
        # Return the connection to the pool even if the query fails.
        session.close()