# ImageData.py 7.4 KB
# coding=utf-8
#author:        4N
#createtime:    2021/10/15
#email:         nheweijun@sina.com

from app.util import *
from osgeo import gdal
from osgeo.gdal import *
from numpy import ndarray
import numpy
from app.modules.service.image.util.ThriftConnect import ThriftConnect,ThriftPool
import json
import gzip

class ImageData:
    """Read a window of pixels from one image for a map-style request.

    Depending on ``image_server`` the pixels are read from a local raster
    through GDAL, fetched from a remote server through Thrift, or replaced by
    a constant nodata block.

    Extents are indexed as ``[min_x, min_y, max_x, max_y]``; that is how every
    comparison and offset computation in this class uses them.
    """

    # Value written to every output pixel the image does not cover.
    NODATA_VALUE = 65536

    # determine_level() moves to a coarser overview while the estimated number
    # of source pixels inside the request is above this value.
    LEVEL_PIXEL_THRESHOLD = 100000

    def __init__(self, image_server, image):
        """Store the server name and the image record.

        :param image_server: server identifier; ``"本地服务器"`` selects the local
            GDAL reader and ``"None"`` selects the nodata block.
        :param image: mapping read with ``.get("path")`` and ``.get("extent")``
            (the extent is a JSON string).
        """
        self.image_server = image_server
        self.image = image

    def get_data(self, extent, bands, height, width):
        """Return a ``(height, width, 3)`` array for the requested extent.

        :param extent: requested extent ``[min_x, min_y, max_x, max_y]``.
        :param bands: band indexes to read.
        :param height: output height in pixels.
        :param width: output width in pixels.
        :return: numpy array of shape ``(height, width, 3)``.
        """
        # Compare with ``==`` rather than calling ``__eq__`` directly: for a
        # non-str server value ``__eq__`` returns NotImplemented, which is
        # truthy and would silently select the local branch.
        if self.image_server == "本地服务器":
            return self.get_local_wms_data(extent, bands, height, width)
        if self.image_server == "None":
            return numpy.full((height, width, 3), self.NODATA_VALUE, dtype=int)
        return self.get_remote_wms_data(extent, bands, height, width)

    @staticmethod
    def _fetch_remote(image_server, image, extent, bands, height, width):
        """Call ``getData`` on the remote server and return its raw result.

        The connection is always closed, including when parsing the image
        extent or the remote call raises.
        """
        # TODO: cache Thrift connections / use a connection pool.
        thrift_connect = ThriftConnect(image_server)
        try:
            image_extent = image.get("extent")
            return thrift_connect.client.getData(
                image.get("path"), extent, json.loads(image_extent), bands, width, height)
        finally:
            thrift_connect.close()

    def get_remote_wms_data(self, extent, bands, height, width):
        """Fetch the data over RPC and decode it into an array.

        The remote payload is gzip-decompressed, interpreted as int64 values
        and reshaped to ``(height, width, 3)``.

        :param extent: requested extent.
        :param bands: band indexes to read.
        :param height: output height in pixels.
        :param width: output width in pixels.
        :return: numpy array of shape ``(height, width, 3)``.
        """
        payload = self._fetch_remote(self.image_server, self.image, extent, bands, height, width)
        raw = gzip.decompress(payload)
        values = numpy.frombuffer(raw, dtype='int64')
        return values.reshape((height, width, 3))

    def get_remote_wms_data_cpp(self, image_server, image, extent, bands, height, width):
        """Fetch the data over RPC and return the payload without decoding it.

        :param image_server: server identifier passed to ``ThriftConnect``.
        :param image: mapping providing ``"path"`` and a JSON ``"extent"``.
        :param extent: requested extent.
        :param bands: band indexes to read.
        :param height: output height in pixels.
        :param width: output width in pixels.
        :return: whatever the remote ``getData`` call returns.
        """
        return self._fetch_remote(image_server, image, extent, bands, height, width)

    def get_local_wms_data(self, extent, bands, height, width):
        """Read the requested extent from the local raster with GDAL.

        Each entry of ``bands`` fills one channel of the result, in order. The
        result always has three channels; channels without a band stay zero.

        :param extent: requested extent ``[min_x, min_y, max_x, max_y]``.
        :param bands: band indexes passed to ``GetRasterBand``.
        :param height: output height in pixels.
        :param width: output width in pixels.
        :return: numpy array of shape ``(height, width, 3)``.
        :raises Exception: when the chosen overview level cannot be obtained.
        """
        pixel_array = numpy.zeros((height, width, 3), dtype=int)
        img: Dataset = gdal.Open(self.image.get("path"), 0)

        # Neither the raster size nor the image extent depends on the band,
        # so they are computed once instead of once per band.
        xysize = [img.RasterXSize, img.RasterYSize]
        origin_extent = json.loads(self.image.get("extent"))

        for channel, band in enumerate(bands):
            band_data: Band = img.GetRasterBand(band)
            pixel_array[:, :, channel] = self._read_band(
                band_data, xysize, origin_extent, extent, height, width)
        return pixel_array

    def _read_band(self, band_data, xysize, origin_extent, extent, height, width):
        """Return one ``(height, width)`` band for the requested extent."""
        # Imported here so the method does not rely on a star import
        # elsewhere in the file to provide ``math``.
        import math

        max_level = band_data.GetOverviewCount()

        # Request lies completely outside the image.
        if (extent[2] < origin_extent[0] or extent[0] > origin_extent[2]
                or extent[1] > origin_extent[3] or extent[3] < origin_extent[1]):
            return numpy.full((height, width), self.NODATA_VALUE, dtype=int)

        # Choose the overview level from the size of the request.
        image_level = self.determine_level(xysize, origin_extent, extent, max_level)
        if image_level == -1:
            # -1 means the full-resolution band itself.
            overview = band_data
        else:
            try:
                overview = band_data.GetOverview(image_level)
            except Exception as err:
                raise Exception("该影像不存在该级别的金字塔数据!") from err
            if overview is None:
                # GetOverview can hand back None instead of raising.
                raise Exception("该影像不存在该级别的金字塔数据!")

        ox = overview.XSize
        oy = overview.YSize

        # Ground size of one overview cell.
        grid_x = (origin_extent[2] - origin_extent[0]) / (ox * 1.0)
        grid_y = (origin_extent[3] - origin_extent[1]) / (oy * 1.0)

        # Request lies strictly inside the image: read the window and let
        # GDAL resample it straight to the output size.
        if (extent[0] > origin_extent[0] and extent[1] > origin_extent[1]
                and extent[2] < origin_extent[2] and extent[3] < origin_extent[3]):
            # Cell offset of the window; rows are counted from max_y.
            off_x = math.floor((extent[0] - origin_extent[0]) / grid_x)
            off_y = math.floor((origin_extent[3] - extent[3]) / grid_y)

            # Number of cells covered by the window.
            x_g = math.ceil((extent[2] - extent[0]) / grid_x)
            y_g = math.ceil((extent[3] - extent[1]) / grid_y)

            return overview.ReadAsArray(off_x, off_y, x_g, y_g, width, height)

        # Partial overlap: read only the intersection and paste it into a
        # nodata-filled output.
        inter_extent = [
            max(origin_extent[0], extent[0]),
            max(origin_extent[1], extent[1]),
            min(origin_extent[2], extent[2]),
            min(origin_extent[3], extent[3]),
        ]

        # Cell offset and cell count of the intersection in the overview.
        off_x = math.floor((inter_extent[0] - origin_extent[0]) / grid_x)
        off_y = math.floor((origin_extent[3] - inter_extent[3]) / grid_y)
        x_g = math.floor((inter_extent[2] - inter_extent[0]) / grid_x)
        y_g = math.floor((inter_extent[3] - inter_extent[1]) / grid_y)

        # Ground size of one output pixel.
        out_grid_x = (extent[2] - extent[0]) / (width * 1.0)
        out_grid_y = (extent[3] - extent[1]) / (height * 1.0)

        # Position and size of the intersection inside the output image.
        out_off_x = int(math.ceil((inter_extent[0] - extent[0]) / out_grid_x))
        out_off_y = int(math.ceil((extent[3] - inter_extent[3]) / out_grid_y))
        out_x_g = int(math.floor((inter_extent[2] - inter_extent[0]) / out_grid_x))
        out_y_g = int(math.floor((inter_extent[3] - inter_extent[1]) / out_grid_y))

        dat = numpy.full((height, width), self.NODATA_VALUE, dtype=int)

        # Extents that only touch, or overlap by less than one cell/pixel,
        # give a zero-sized window that cannot be read; the band is then
        # entirely nodata.
        if min(x_g, y_g, out_x_g, out_y_g) < 1:
            return dat

        overview_raster: ndarray = overview.ReadAsArray(off_x, off_y, x_g, y_g, out_x_g, out_y_g)
        dat[out_off_y:out_off_y + out_y_g, out_off_x:out_off_x + out_x_g] = overview_raster
        return dat

    def determine_level(self, xysize, origin_extent, extent, max_level):
        """Choose which overview (pyramid) level to read for a request.

        The number of source pixels inside ``extent`` is estimated from the
        ratio of the requested area to the image area. While that estimate is
        above ``LEVEL_PIXEL_THRESHOLD`` and a coarser level exists, the level
        is increased and both raster dimensions are halved.

        :param xysize: ``[x_size, y_size]`` of the full-resolution raster.
        :param origin_extent: extent of the image.
        :param extent: requested extent.
        :param max_level: number of overview levels available.
        :return: ``-1`` for the full-resolution band, otherwise an overview
            index in ``0 .. max_level - 1``.
        """
        x = xysize[0]
        y = xysize[1]

        # Fraction of the image area covered by the request; it does not
        # change between levels, so it is computed once.
        area_ratio = ((extent[2] - extent[0]) * (extent[3] - extent[1])) / (
            (origin_extent[2] - origin_extent[0]) * (origin_extent[3] - origin_extent[1]))

        level = -1
        pixel = x * y * area_ratio
        while pixel > self.LEVEL_PIXEL_THRESHOLD and level < max_level - 1:
            level += 1
            x = x / 2
            y = y / 2
            pixel = x * y * area_ratio
        return level