# ImageData.py 5.5 KB
# coding=utf-8
#author:        4N
#createtime:    2021/10/15
#email:         nheweijun@sina.com

from app.util import *
from osgeo import gdal
from osgeo.gdal import *
from numpy import ndarray
import numpy
from app.modules.service.image.util.ThriftConnect import ThriftConnect,ThriftPool
import json
import gzip
import math

class ImageData:
    """Read 3-channel tile data for one image, locally via GDAL or from a remote server.

    ``image_server`` selects the data source (see :meth:`get_data`). ``image`` is
    the image record; this class reads its ``path`` attribute and its ``extent``
    attribute, a JSON string holding ``[min_x, min_y, max_x, max_y]``.
    """

    # Fill value for tile pixels that are not covered by the image.
    NODATA_VALUE = 65536
    # Every tile produced by this class has this many channels.
    CHANNEL_COUNT = 3

    def __init__(self, image_server, image):
        self.image_server = image_server
        self.image = image

    def get_data(self, extent, bands, height, width):
        """Return the tile for ``extent``, dispatching on ``self.image_server``.

        :param extent: requested ``[min_x, min_y, max_x, max_y]``
        :param bands: band indexes to read
        :param height: tile height in pixels
        :param width: tile width in pixels
        :return: ``(height, width, 3)`` integer array
        """
        # Use ``==`` rather than calling ``__eq__`` directly: for a non-str
        # server value ``__eq__`` returns NotImplemented, which is truthy and
        # would wrongly select the local branch.
        if self.image_server == "本地服务器":  # literal means "local server"
            return self.get_local_wms_data(extent, bands, height, width)
        if self.image_server == "None":
            return self._empty_tile(height, width)
        return self.get_remote_wms_data(extent, bands, height, width)

    def _empty_tile(self, height, width):
        """Return a ``(height, width, 3)`` int array filled with ``NODATA_VALUE``."""
        return numpy.full((height, width, self.CHANNEL_COUNT), self.NODATA_VALUE, dtype=int)

    @staticmethod
    def _request_remote_data(image_server, image, extent, bands, height, width):
        """Call ``getData`` on ``image_server`` and return its raw response.

        The connection is closed even when the request raises.
        """
        # Parse before connecting so a malformed extent never opens a connection.
        image_extent = json.loads(image.extent)

        # TODO: thrift connections should be cached / pooled instead of being
        # created for every request.
        thrift_connect = ThriftConnect(image_server)
        try:
            # NOTE: the remote signature takes width before height.
            return thrift_connect.client.getData(image.path, extent, image_extent, bands, width, height)
        finally:
            thrift_connect.close()

    def get_remote_wms_data(self, extent, bands, height, width):
        """Fetch the tile over RPC from ``self.image_server`` and decode it.

        The response is gzip-decompressed, interpreted as int64 values and
        reshaped to ``(height, width, 3)``.

        :param extent: requested ``[min_x, min_y, max_x, max_y]``
        :param bands: band indexes forwarded to the server
        :param height: tile height in pixels
        :param width: tile width in pixels
        :return: ``(height, width, 3)`` int64 array
        """
        payload = self._request_remote_data(self.image_server, self.image, extent, bands, height, width)
        pixels = numpy.frombuffer(gzip.decompress(payload), dtype='int64')
        return pixels.reshape((height, width, self.CHANNEL_COUNT))

    def get_remote_wms_data_cpp(self, image_server, image, extent, bands, height, width):
        """Fetch the tile over RPC and return the server response undecoded.

        Unlike :meth:`get_remote_wms_data`, the server and image are passed in
        explicitly and the response is returned exactly as received.

        :param image_server: server to connect to
        :param image: image record providing ``path`` and ``extent``
        :param extent: requested ``[min_x, min_y, max_x, max_y]``
        :param bands: band indexes forwarded to the server
        :param height: tile height in pixels
        :param width: tile width in pixels
        :return: raw ``getData`` response
        """
        return self._request_remote_data(image_server, image, extent, bands, height, width)

    def get_local_wms_data(self, extent, bands, height, width):
        """Read the tile for ``extent`` from the local image file with GDAL.

        Pixels outside the image keep ``NODATA_VALUE``. Up to three entries of
        ``bands`` are read, one per output channel in order; channels without a
        matching band keep ``NODATA_VALUE``.

        :param extent: requested ``[min_x, min_y, max_x, max_y]``
        :param bands: GDAL band indexes (1-based) for the output channels
        :param height: tile height in pixels
        :param width: tile width in pixels
        :return: ``(height, width, 3)`` integer array
        :raises RuntimeError: if GDAL cannot open the image or read a band
        """
        origin_extent = json.loads(self.image.extent)
        tile = self._empty_tile(height, width)

        # Requested extent lies completely outside the image: nothing to read,
        # so the file is never opened.
        is_disjoint = (
            extent[2] < origin_extent[0]
            or extent[0] > origin_extent[2]
            or extent[1] > origin_extent[3]
            or extent[3] < origin_extent[1]
        )
        if is_disjoint:
            return tile

        dataset = gdal.Open(self.image.path, 0)
        if dataset is None:
            raise RuntimeError(f"GDAL could not open image: {self.image.path}")

        raster_width = dataset.RasterXSize
        raster_height = dataset.RasterYSize

        # Size of one source pixel in extent units.
        grid_x = (origin_extent[2] - origin_extent[0]) / raster_width
        grid_y = (origin_extent[3] - origin_extent[1]) / raster_height

        fully_inside = (
            extent[0] > origin_extent[0]
            and extent[1] > origin_extent[1]
            and extent[2] < origin_extent[2]
            and extent[3] < origin_extent[3]
        )

        if fully_inside:
            # Source window (pixel offset and size) covering the whole request.
            off_x = math.floor((extent[0] - origin_extent[0]) / grid_x)
            off_y = math.floor((origin_extent[3] - extent[3]) / grid_y)
            src_width = math.ceil((extent[2] - extent[0]) / grid_x)
            src_height = math.ceil((extent[3] - extent[1]) / grid_y)

            # The window fills the entire output tile.
            out_off_x = 0
            out_off_y = 0
            out_width = width
            out_height = height
        else:
            # Partial overlap: only the intersection can be read.
            inter_extent = [
                max(origin_extent[0], extent[0]),
                max(origin_extent[1], extent[1]),
                min(origin_extent[2], extent[2]),
                min(origin_extent[3], extent[3]),
            ]

            # Source window of the intersection.
            off_x = math.floor((inter_extent[0] - origin_extent[0]) / grid_x)
            off_y = math.floor((origin_extent[3] - inter_extent[3]) / grid_y)
            src_width = math.floor((inter_extent[2] - inter_extent[0]) / grid_x)
            src_height = math.floor((inter_extent[3] - inter_extent[1]) / grid_y)

            # Size of one output pixel in extent units.
            out_grid_x = (extent[2] - extent[0]) / width
            out_grid_y = (extent[3] - extent[1]) / height

            # Where the intersection lands inside the output tile.
            out_off_x = int(math.ceil((inter_extent[0] - extent[0]) / out_grid_x))
            out_off_y = int(math.ceil((extent[3] - inter_extent[3]) / out_grid_y))
            out_width = int(math.floor((inter_extent[2] - inter_extent[0]) / out_grid_x))
            out_height = int(math.floor((inter_extent[3] - inter_extent[1]) / out_grid_y))

        # Keep both windows inside their arrays despite float rounding.
        src_width = min(src_width, raster_width - off_x)
        src_height = min(src_height, raster_height - off_y)
        out_width = min(out_width, width - out_off_x)
        out_height = min(out_height, height - out_off_y)

        # An overlap thinner than one pixel (e.g. extents that only touch)
        # leaves nothing to read; GDAL rejects zero-sized windows.
        if min(src_width, src_height, out_width, out_height) <= 0:
            return tile

        rows = slice(out_off_y, out_off_y + out_height)
        cols = slice(out_off_x, out_off_x + out_width)
        for channel, band_index in zip(range(self.CHANNEL_COUNT), bands):
            # Band.ReadAsArray takes (xoff, yoff, win_xsize, win_ysize,
            # buf_xsize, buf_ysize) and resamples the window to the buffer size.
            band_pixels = dataset.GetRasterBand(band_index).ReadAsArray(
                off_x, off_y, src_width, src_height, out_width, out_height)
            if band_pixels is None:
                raise RuntimeError(f"GDAL could not read band {band_index} of {self.image.path}")
            tile[rows, cols, channel] = band_pixels

        return tile