# image_register.py 7.8 KB
# coding=utf-8
#author:        4N
#createtime:    2021/7/19
#email:         nheweijun@sina.com


from osgeo import gdal,osr
from osgeo.gdal import Dataset,Band
from app.util.component.ApiTemplate import ApiTemplate
from app.modules.service.image.util.ThriftConnect import ThriftConnect
import json
from ..models import Image
import datetime
from app.models import  db
import uuid
import os
from ..models import ImageTag
from .util.ImageType import ImageType

class Api(ApiTemplate):
    """API endpoint that registers raster images in the ``Image`` table.

    The request names a data server and a list of paths (files and/or
    directories).  Directories are expanded recursively, metadata is read
    for every image found (locally through GDAL, remotely through the
    Thrift client) and one ``Image`` row is created per new image.  An image
    that already exists (same normalized path and size) only gets the data
    server appended to its ``server`` list.
    """

    api_name = "注册影像数据"

    # Value of ``data_server`` that means "read the files on this machine".
    _LOCAL_SERVER = "本地服务器"

    # File extension (lower case, with the dot) -> value stored in "type".
    _IMAGE_EXTENSIONS = {".tif": "tif", ".tiff": "tif", ".img": "img"}

    def process(self):
        """Register every image reachable from the requested paths.

        Parameters read from ``self.para``:
            data_server: name of the server holding the images; the literal
                "本地服务器" selects the local file system, anything else is
                handed to ``ThriftConnect``.
            paths: JSON string holding a list of dicts with at least the
                keys ``"type"`` (``"dir"`` for directories) and ``"path"``.
            tag_guids: optional comma-separated ``ImageTag`` guids attached
                to every newly created image.

        Returns:
            dict: ``{"result": True}`` once the session has been committed.

        Raises:
            Any exception raised while reading or saving propagates
            unchanged to the caller (the original code caught and re-raised
            it, which is equivalent).
        """
        res = {}

        data_server = self.para.get("data_server")
        paths = json.loads(self.para.get("paths"))
        tag_guids = self.para.get("tag_guids")
        if tag_guids:
            tags = db.session.query(ImageTag).filter(
                ImageTag.guid.in_(tag_guids.split(","))).all()
        else:
            tags = []

        # Plain ``==`` instead of ``data_server.__eq__(...)``: on a missing
        # (None) data_server the dunder call returns the truthy
        # ``NotImplemented`` and the request was silently treated as local.
        if data_server == self._LOCAL_SERVER:
            infos = self._collect_local_infos(paths)
        else:
            # Distributed deployment: metadata comes from the Thrift service.
            infos = self._collect_remote_infos(data_server, paths)

        self._save_infos(infos, data_server, tags)
        db.session.commit()
        res["result"] = True

        return res

    def _collect_local_infos(self, paths):
        """Return the metadata dict of every local image under ``paths``."""
        return [self._read_local_info(image_info["path"])
                for image_info in self.recur_paths_local(paths)]

    def _read_local_info(self, path):
        """Read the metadata of one local raster file with GDAL.

        Raises:
            RuntimeError: if GDAL cannot open ``path`` (``gdal.Open``
                returned ``None``); previously this surfaced as an opaque
                ``AttributeError`` on ``None``.
        """
        image: Dataset = gdal.Open(path, 0)
        if image is None:
            raise RuntimeError("Cannot open image: {}".format(path))

        geo = image.GetGeoTransform()
        crs_wkt = image.GetProjection()

        origin = osr.SpatialReference()
        origin.ImportFromWkt(crs_wkt)

        band_count = image.RasterCount
        band: Band = image.GetRasterBand(1)
        x_size = image.RasterXSize
        y_size = image.RasterYSize

        # Corner coordinates from the geotransform; the rotation terms
        # geo[2] and geo[4] are ignored, exactly as before.
        left, top = geo[0], geo[3]
        right = geo[0] + geo[1] * x_size
        bottom = geo[3] + geo[5] * y_size

        # ``image`` and ``band`` are locals, so the dataset reference is
        # dropped when this function returns (replaces the ``del image``).
        return {"band_count": band_count,
                "band_view": "[1,2,3]" if band_count >= 3 else "[1,1,1]",
                "overview_count": band.GetOverviewCount(),
                "path": path,
                "xy_size": [x_size, y_size],
                "origin_extent": [left, bottom, right, top],
                "null_value": band.GetNoDataValue(),
                "size": os.path.getsize(path),
                "crs_wkt": crs_wkt,
                "crs": origin.GetAuthorityCode(None),
                "crs_proj4": origin.ExportToProj4(),
                "cell_x_size": geo[1],
                "cell_y_size": geo[5]
                }

    def _collect_remote_infos(self, data_server, paths):
        """Return the metadata dict of every image under ``paths`` on ``data_server``."""
        thrift_connect = ThriftConnect(data_server)
        try:
            client = thrift_connect.client
            # Expand all directories first, then fetch the infos, keeping
            # the original order of remote calls.
            image_paths = list(self.recur_paths_remote(paths, client))
            return [json.loads(client.getInfo(image_info["path"]))
                    for image_info in image_paths]
        finally:
            # Close the connection even when a remote call fails.
            thrift_connect.close()

    def _save_infos(self, infos, data_server, tags):
        """Add new ``Image`` rows / extend the server list of known images.

        Nothing is committed here; the caller commits the session.
        """
        this_time = datetime.datetime.now()

        for info in infos:
            norm_path = os.path.normpath(info.get("path"))
            size = info.get("size")

            exist_query = Image.query.filter_by(path=norm_path, size=size)
            exist_image = exist_query.one_or_none()
            if exist_image:
                # ``server`` is a comma-separated list; compare whole names
                # so that e.g. "srv1" is not considered present in "srv10".
                servers = exist_image.server.split(",") if exist_image.server else []
                if data_server not in servers:
                    servers.append(data_server)
                    exist_query.update({"server": ",".join(servers)})
                continue

            # Remote infos may lack some keys; treat missing counts as 0.
            overview_count = info.get("overview_count")
            band_count = info.get("band_count")
            cell_y_size = info.get("cell_y_size")
            file_name = os.path.basename(info.get("path"))

            img: Image = Image(guid=str(uuid.uuid1()),
                               overview_count=overview_count,
                               has_pyramid=1 if (overview_count or 0) > 0 else 0,
                               raster_x_size=info["xy_size"][0],
                               raster_y_size=info["xy_size"][1],
                               cell_x_size=info.get("cell_x_size"),
                               # geo[5] is stored as-is in the info; the row
                               # keeps its absolute value.
                               cell_y_size=abs(cell_y_size) if cell_y_size is not None else None,
                               name=file_name,
                               alias=file_name,
                               extent=json.dumps(info["origin_extent"]),
                               null_value=info.get("null_value"),
                               server=data_server,
                               path=norm_path,
                               size=size,
                               crs=str(info.get("crs")),
                               crs_wkt=info.get("crs_wkt"),
                               crs_proj4=info.get("crs_proj4"),
                               band_count=band_count,
                               band_view="[1,2,3]" if (band_count or 0) >= 3 else "[1,1,1]",
                               create_time=this_time,
                               update_time=this_time,
                               type=ImageType.get_type(info.get("path"))
                               )
            for tag in tags:
                img.image_tags.append(tag)
            db.session.add(img)

    def recur_paths_local(self, paths):
        """Yield every non-directory entry of ``paths``, expanding directories.

        Args:
            paths: iterable of dicts with the keys ``"type"`` and ``"path"``.

        Yields:
            dict: entries whose ``"type"`` is not ``"dir"``.  Entries found
            while scanning a directory carry ``"name"``, ``"path"`` and
            ``"type"`` (``"tif"`` or ``"img"``); other files are skipped.
        """
        for path in paths:
            if path["type"] != "dir":
                yield path
                continue

            data_list: list = []
            for f in os.listdir(path["path"]):
                file_path = os.path.normpath(os.path.join(path["path"], f))
                # Match the real extension (with the dot) so that names
                # merely ending in "tif"/"img" are not taken for images.
                extension = os.path.splitext(file_path)[1].lower()
                if extension in self._IMAGE_EXTENSIONS:
                    entry_type = self._IMAGE_EXTENSIONS[extension]
                elif os.path.isdir(file_path):
                    entry_type = "dir"
                else:
                    continue
                data_list.append({"name": f, "path": file_path, "type": entry_type})

            yield from self.recur_paths_local(data_list)

    def recur_paths_remote(self, paths, client):
        """Yield every non-directory entry of ``paths`` on the remote server.

        Directories are expanded with ``client.getImageList``, whose JSON
        result is fed back into this generator.
        """
        for path in paths:
            if path["type"] == "dir":
                path_list = json.loads(client.getImageList(path["path"]))
                yield from self.recur_paths_remote(path_list, client)
            else:
                yield path

    api_doc = {
        "tags": ["影像接口"],
        "parameters": [
            {"name": "data_server",
             "in": "formData",
             "type": "string",
             "description": "data_server"},
            {"name": "paths",
             "in": "formData",
             "type": "string",
             "description": "paths"},
            # Documented under the name that process() actually reads
            # ("tag_guids"); it was previously documented as "tags".
            {"name": "tag_guids",
             "in": "formData",
             "type": "string",
             "description": "tags以,相隔"}
        ],
        "responses": {
            200: {
                "schema": {
                    "properties": {
                    }
                }
            }
        }
    }