database_register.py 13.0 KB
# coding=utf-8
#author:        4N
#createtime:    2021/3/9
#email:         nheweijun@sina.com
import datetime
from ..models import Database,db,Table,Columns,TableVacuate,DES
from app.models import AESHelper

import uuid
from . import database_test
from osgeo.ogr import DataSource,Layer,FeatureDefn,FieldDefn
from sqlalchemy.orm import Session
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.util.component.ApiTemplate import ApiTemplate
from app.util.component.PGUtil import PGUtil
from app.util.component.SQLUtil import SQLUtil
from app.util.component.GeometryAdapter import GeometryAdapter
from app.util.component.StructurePrint import StructurePrint


class Api(ApiTemplate):
    """Register an external PostgreSQL/PostGIS database and catalogue its tables.

    ``process`` reads the connection parameters from ``self.para``, rejects
    duplicates (same alias or same connection string) and databases on which
    ``st_geometryfromtext`` cannot run, then stores a ``Database`` row plus a
    ``Table``/``Columns`` row for every readable table of the ``public`` schema.
    """

    api_name = "注册数据库"

    def process(self):
        """Validate the request parameters and register the database.

        Returns:
            dict: ``{"result": bool, "msg": str}``; on success also ``"data"``
            holding the new database guid. If the connection test fails, the
            result of ``database_test.Api`` is returned unchanged.

        Raises:
            Exception: any unexpected error, re-raised after ``db.session`` is
            rolled back.
        """
        res = {"result": False}
        try:
            host = self.para.get("host")
            port = self.para.get("port")
            user = self.para.get("user")
            passwd = self.para.get("passwd")
            db_name = self.para.get("database")
            alias = self.para.get("alias")
            # The client may send the password AES-encrypted.
            if int(self.para.get("encryption", "0")):
                passwd = AESHelper.decode(passwd)

            sqlalchemy_uri = "postgresql://{}:{}@{}:{}/{}".format(user, passwd, host, port, db_name)
            connectstr = "hostaddr={} port={} dbname='{}' user='{}' password='{}'".format(
                host, port, db_name, user, passwd)

            # Test the connection before touching the metadata store.
            test = database_test.Api().result
            if not test["result"]:
                return test

            # An alias may only be registered once.
            if db.session.query(Database).filter_by(alias=alias).one_or_none():
                res["msg"] = "数据库已存在,请修改别名!"
                return res

            # The same physical connection may only be registered once.
            encoded_connectstr = DES.encode(connectstr)
            if db.session.query(Database).filter_by(connectstr=encoded_connectstr).all():
                res["msg"] = "数据库连接已存在,请修改数据库连接!"
                return res

            if not self.check_space(sqlalchemy_uri):
                res["msg"] = "数据库不是空间数据库!"
                return res

            this_time = datetime.datetime.now()
            database_guid = str(uuid.uuid1())
            db_tuple = PGUtil.get_info_from_sqlachemy_uri(sqlalchemy_uri)

            new_database = Database(guid=database_guid,
                                    name=db_tuple[4],
                                    alias=alias,
                                    sqlalchemy_uri=DES.encode(sqlalchemy_uri),
                                    description=self.para.get("description"),
                                    creator=self.para.get("creator"),
                                    create_time=this_time,
                                    update_time=this_time,
                                    connectstr=encoded_connectstr)
            db.session.add(new_database)

            # Register every table of the new database in the same transaction.
            self.register_table(new_database)
            db.session.commit()
            res["msg"] = "注册成功!"
            res["result"] = True
            res["data"] = database_guid
        except Exception:
            db.session.rollback()
            raise
        return res

    def register_table(self, database):
        """Catalogue the readable tables of ``database`` into ``db.session``.

        Spatial layers are registered first through OGR; the remaining plain
        tables are registered afterwards, skipping names already handled.

        Args:
            database: the ``Database`` model instance being registered.
        """
        this_time = datetime.datetime.now()
        sqlalchemy_uri = DES.decode(database.sqlalchemy_uri)
        db_tuple = PGUtil.get_info_from_sqlachemy_uri(sqlalchemy_uri)
        pg_ds: DataSource = PGUtil.open_pg_data_source(1, sqlalchemy_uri)
        try:
            # Spatial tables
            spatial_table_name, _tables = self.register_spatial_table(pg_ds, database, this_time, db_tuple)
            # Plain (non-spatial) tables
            self.register_common_table(this_time, database, spatial_table_name, db_tuple)
        finally:
            # Release the OGR connection even when registration fails.
            if pg_ds is not None:
                pg_ds.Destroy()

    @staticmethod
    def _quote_literal(value):
        """Return ``value`` as a single-quoted SQL string literal with quotes doubled."""
        return "'{}'".format(str(value).replace("'", "''"))

    def _get_count_and_extent(self, pg_ds, layer, l_name):
        """Return ``(feature_count, extent)`` for the public table ``l_name``.

        The planner estimate ``pg_class.reltuples`` picks the strategy: below
        1,000,000 rows the count and extent are computed exactly through OGR;
        otherwise the estimate is kept and the extent comes from
        ``ST_EstimatedExtent``. Every ``ExecuteSQL`` result set is released.
        """
        regclass = 'public."{}"'.format(l_name.replace('"', '""'))
        count_layer: Layer = pg_ds.ExecuteSQL(
            "SELECT reltuples::bigint AS ec FROM pg_class WHERE  oid = {}::regclass".format(
                self._quote_literal(regclass)))
        try:
            feature_count = count_layer.GetFeature(0).GetField("ec")
        finally:
            if count_layer is not None:
                pg_ds.ReleaseResultSet(count_layer)

        # Fewer than 1,000,000 features: exact statistics are affordable.
        if feature_count < 1000000:
            return layer.GetFeatureCount(), layer.GetExtent()

        ext_layer: Layer = pg_ds.ExecuteSQL(
            "select geometry(ST_EstimatedExtent('public', {}, {}))".format(
                self._quote_literal(l_name), self._quote_literal(layer.GetGeometryColumn())))
        try:
            return feature_count, ext_layer.GetExtent()
        finally:
            if ext_layer is not None:
                pg_ds.ReleaseResultSet(ext_layer)

    def register_spatial_table(self, pg_ds, database, this_time, db_tuple):
        """Register the spatial layers of the ``public`` schema.

        A layer that fails is reported with a warning and skipped; nothing is
        added to the session for it.

        Returns:
            tuple: ``(spatial_table_name, tables)``. ``spatial_table_name`` holds
            the names handled here (including vacuate layers) so the plain-table
            pass skips them. ``tables`` holds the ``Table`` objects created.
        """
        spatial_table_name = []
        tables = []
        for layer_index in range(pg_ds.GetLayerCount()):
            layer: Layer = pg_ds.GetLayer(layer_index)
            l_name = layer.GetName()
            try:
                # Only public tables are registered; other schemas show up as "schema.table".
                if "." in l_name:
                    continue

                # Vacuate (simplified) tables are not registered, but are recorded
                # so the plain-table pass ignores them too.
                if "_vacuate_" in l_name:
                    spatial_table_name.append(l_name)
                    continue

                # Skip tables the user cannot SELECT from.
                if not PGUtil.check_table_privilege(l_name, "SELECT", db_tuple[0], pg_ds):
                    StructurePrint().print("用户{}对表{}没有select权限!".format(db_tuple[0], l_name), "warn")
                    continue

                feature_count, ext = self._get_count_and_extent(pg_ds, layer, l_name)
                # Small first coordinate: presumably geographic degrees (keep 6
                # decimals); otherwise projected units (keep 2).
                digits = 6 if ext[0] < 360 else 2
                ext = [round(e, digits) for e in ext]
                extent = "{},{},{},{}".format(ext[0], ext[1], ext[2], ext[3])

                table_guid = str(uuid.uuid1())
                geom_type = GeometryAdapter.get_geometry_type(layer)
                table = Table(guid=table_guid,
                              database_guid=database.guid,
                              name=l_name, create_time=this_time, update_time=this_time,
                              table_type=GeometryAdapter.get_table_type(geom_type),
                              extent=extent,
                              feature_count=feature_count)

                feature_defn: FeatureDefn = layer.GetLayerDefn()
                columns = []
                for field_index in range(feature_defn.GetFieldCount()):
                    field_defn: FieldDefn = feature_defn.GetFieldDefn(field_index)
                    field_name = field_defn.GetName()
                    # Fall back to the field name when no alternative name is set.
                    field_alias = field_defn.GetAlternativeName() or field_name
                    columns.append(Columns(guid=str(uuid.uuid1()), table_guid=table_guid,
                                           name=field_name, alias=field_alias,
                                           create_time=this_time, update_time=this_time))

                # Add to the session only after the whole layer was read successfully.
                db.session.add(table)
                for column in columns:
                    db.session.add(column)
                tables.append(table)
                spatial_table_name.append(l_name)
            except Exception:
                StructurePrint().print("表{}注册失败!".format(l_name), "warn")
                continue
        return spatial_table_name, tables

    def register_common_table(self, this_time, database, spatial_table_name, db_tuple):
        """Register the plain tables of the ``public`` schema not handled as spatial.

        Args:
            this_time: timestamp used for create/update times.
            database: the ``Database`` model instance being registered.
            spatial_table_name: names already handled by the spatial pass.
            db_tuple: connection info; element 0 is used as the user name.
        """
        db_session: Session = PGUtil.get_db_session(DES.decode(database.sqlalchemy_uri))
        try:
            # Only tables in public (relnamespace 2200), excluding pg_* / sql_* tables.
            result = db_session.execute(
                "select relname as tabname from pg_class c where  relkind = 'r' and relnamespace=2200 and relname not like 'pg_%' and relname not like 'sql_%' order by relname").fetchall()

            already_registered = set(spatial_table_name)
            for row in result:
                table_name = row[0]
                if table_name in already_registered:
                    continue

                # Skip tables the user cannot SELECT from.
                if not SQLUtil.check_table_privilege(table_name, "SELECT", db_tuple[0], db_session):
                    StructurePrint().print("用户{}对表{}没有select权限!".format(db_tuple[0], table_name), "warn")
                    continue

                table_guid = str(uuid.uuid1())
                count = SQLUtil.get_table_count(table_name, db_session)
                table = Table(guid=table_guid,
                              database_guid=database.guid,
                              name=table_name, create_time=this_time, update_time=this_time,
                              table_type=0,
                              feature_count=count)
                db.session.add(table)

                # Restrict to the public schema so a same-named table in another
                # schema does not contribute its columns.
                sql = '''
                SELECT
                    a.attnum,
                    a.attname AS field
                FROM
                    pg_class c,
                    pg_attribute a,
                    pg_type t
                WHERE
                    c.relname = {}
                    and c.relnamespace = 2200
                    and a.attnum > 0
                    and not a.attisdropped
                    and a.attrelid = c.oid
                    and a.atttypid = t.oid
                ORDER BY a.attnum
                '''.format(self._quote_literal(table_name))

                for col in db_session.execute(sql).fetchall():
                    column = Columns(guid=str(uuid.uuid1()), table_guid=table_guid,
                                     name=col[1], create_time=this_time, update_time=this_time)
                    db.session.add(column)
            db_session.commit()
        finally:
            db_session.close()

    def regiser_vacuate_table(self, pg_ds, tables, db_tuple):
        """Register vacuate layers as ``TableVacuate`` rows of their base table.

        Not called anywhere in this file. Layer names are parsed as
        ``..._vacuate_..._<level>_<pixel>``.
        """
        for layer_index in range(pg_ds.GetLayerCount()):
            layer: Layer = pg_ds.GetLayer(layer_index)
            l_name = layer.GetName()
            if "_vacuate_" not in l_name:
                continue

            # NOTE(review): index [1] yields the text AFTER "_vacuate_", which also
            # contains the level/pixel suffix; confirm the naming scheme (the base
            # name may be at [0]).
            base_layer_name = l_name.split("_vacuate_")[1]
            level = l_name.split("_")[-2]
            pixel_distance_str: str = "0"
            try:
                pixel_distance_str = l_name.split("_")[-1]
                # A leading zero encodes a fractional distance, e.g. "0001" -> "0.0001".
                if pixel_distance_str.startswith("0"):
                    pixel_distance_str = "0.{}".format(pixel_distance_str)
            except Exception:
                pass

            base_table = [table for table in tables if table.name == base_layer_name]
            if base_table:
                base_table = base_table[0]
                table_vacuate = TableVacuate(guid=str(uuid.uuid1()),
                                             table_guid=base_table.guid,
                                             level=level,
                                             name=l_name,
                                             pixel_distance=float(pixel_distance_str))
                Table.query.filter_by(guid=base_table.guid).update({"is_vacuate": 1})
                db.session.add(table_vacuate)

    def check_space(self, sqlachemy_uri):
        """Return True if ``st_geometryfromtext`` runs on the target database.

        Returns False on any failure (connection or missing function). The
        temporary session is closed and the engine's connection pool disposed.
        """
        engine = None
        system_session = None
        try:
            engine = create_engine(sqlachemy_uri)
            system_session = sessionmaker(bind=engine)()
            system_session.execute("select st_geometryfromtext('POINT(1 1)')").fetchone()
            return True
        except Exception:
            return False
        finally:
            if system_session is not None:
                system_session.close()
            if engine is not None:
                engine.dispose()

    api_doc = {
        "tags": ["数据库接口"],
        "parameters": [
            {"name": "host",
             "in": "formData",
             "type": "string", "required": "true"},
            {"name": "port",
             "in": "formData",
             "type": "string", "required": "true"},
            {"name": "user",
             "in": "formData",
             "type": "string", "required": "true"},
            {"name": "passwd",
             "in": "formData",
             "type": "string", "required": "true"},
            {"name": "database",
             "in": "formData",
             "type": "string", "required": "true"},
            {"name": "creator",
             "in": "formData",
             "type": "string", "required": "true"},

            {"name": "alias",
             "in": "formData",
             "type": "string", "description": "数据库别名", "required": "true"},
            {"name": "encryption",
             "in": "formData",
             "type": "int", "description": "密码是否加密", "enum": [0, 1]}
        ],
        "responses": {
            200: {
                "schema": {
                    "properties": {
                    }
                }
            }
        }
    }