# image_wms_temporary.py 7.6 KB
# coding=utf-8
#author:        4N
#createtime:    2021/3/24
#email:         nheweijun@sina.com

import traceback

import numpy
from flask import Response
import random

from app.modules.service.models import Image
from app.util.component.ApiTemplate import ApiTemplate

from app.util.component.ParameterUtil import ParameterUtil
import json

from .util.MyThread import MyThread
from .util.ImageData import ImageData
from .util.Opencv import Opencv
from .util.Cache import Cache
from app.util.component.ModelVisitor import ModelVisitor
from app.util.component.ParameterUtil import StructurePrint
class Api(ApiTemplate):
    """Preview endpoint used while an image service is being published.

    Depending on the request parameters, ``process`` either returns the
    combined extent of the requested images or renders a WMS-style picture
    of them for a bounding box.
    """

    api_name = "发布服务时预览"

    # Pixel value this module uses to mark "no data": result buffers are
    # pre-filled with it and the first band of fetched data is compared to it.
    _NODATA = 65536

    def process(self):
        """Handle one preview request.

        Request parameters (read from ``self.para``):
            image_guids: comma-separated image GUIDs.
            get_extent: when true (boolean or the string "true"), only the
                combined extent of the images is returned.
            bbox: comma-separated floats describing the requested extent.
            width, height: output size in pixels, default 256.
            format: output mime type, default "image/png".
            quality: value forwarded to the image encoder, default 30.
            overview: when truthy, the raw pixel array is returned instead of
                an HTTP response.

        Returns:
            * a ``dict`` with ``result`` plus ``data`` or ``msg`` when
              ``get_extent`` is set;
            * the ``numpy`` pixel array when ``overview`` is set;
            * otherwise a Flask ``Response`` wrapping the encoded image.
        """
        # presumably normalises the parameter keys to lower case -- the helper
        # is defined outside this file
        parameter: dict = ParameterUtil.to_lower(self.para)

        image_guids = parameter.get("image_guids")

        if self._flag_enabled(parameter.get("get_extent")):
            return self._merged_extent(image_guids)

        bbox = parameter.get("bbox")
        width = int(parameter.get("width") or 256)
        height = int(parameter.get("height") or 256)
        image_type = parameter.get("format") or "image/png"
        quality = int(parameter.get("quality") or 30)

        # Only the third element (the known servers) is needed here.
        _, _, servers = Cache.cache_data(None)

        extent = [float(x) for x in bbox.split(",")]

        images = Image.query.filter(Image.guid.in_(image_guids.split(","))).all()
        intersect_image = [
            im for im in images
            if self.determin_intersect(json.loads(im.extent), extent)
        ]

        if len(intersect_image) > 1:
            pixel_array = self._render_mosaic(
                intersect_image, servers, extent, height, width)
        elif len(intersect_image) == 1:
            pixel_array = self._render_single(
                intersect_image[0], servers, extent, height, width)
        else:
            # Nothing overlaps the requested extent: answer with a picture
            # made only of the no-data value.
            pixel_array = numpy.full((height, width, 3), self._NODATA, dtype=int)

        # Encode the picture in memory so it can be returned directly.
        # NOTE(review): the encoded bytes are unused when "overview" is set;
        # the call order is kept as-is because Opencv.create_image is not
        # visible from this file.
        im_data = Opencv.create_image(image_type, pixel_array, quality)

        # "overview" is looked up in the original, un-normalised parameters.
        if self.para.get("overview"):
            return pixel_array
        return Response(im_data, mimetype=image_type.lower())

    @staticmethod
    def _flag_enabled(value):
        """Return True for boolean-like true values and the string "true" (any case)."""
        if isinstance(value, str):
            return value.lower() == "true"
        # Covers True and 1; every other non-string value counts as "off"
        # instead of failing on a missing ``lower`` attribute.
        return value == 1

    @staticmethod
    def _merged_extent(image_guids):
        """Build the ``get_extent`` answer: the union of the images' extents.

        GUIDs that match no image are skipped. When the matched images do not
        all share one ``crs`` the answer carries ``result`` False and a message.
        """
        result = {}
        crs_seen = set()
        union = []
        for guid in image_guids.split(","):
            image = Image.query.filter_by(guid=guid).one_or_none()
            if not image:
                # Unknown GUID: nothing to merge and no crs to record.
                continue
            crs_seen.add(image.crs)
            image_extent = json.loads(image.extent)
            if not union:
                union = image_extent
                continue
            # Indices 0/1 shrink with min, indices 2/3 grow with max.
            union[0] = min(image_extent[0], union[0])
            union[1] = min(image_extent[1], union[1])
            union[2] = max(image_extent[2], union[2])
            union[3] = max(image_extent[3], union[3])

        if len(crs_seen) > 1:
            result["result"] = False
            result["msg"] = "影像坐标不一致"
            return result

        result["result"] = True
        result["data"] = union
        return result

    @staticmethod
    def _pick_server(image, servers):
        """Pick at random one of the image's servers that is also in ``servers``.

        Returns the literal string "None" when none of them qualifies.
        """
        candidates = [ser for ser in image.server.split(",") if ser in servers]
        return random.choice(candidates) if candidates else "None"

    def _open_image(self, image, servers):
        """Return ``(reader, bands)`` for one image record."""
        image_server = self._pick_server(image, servers)
        bands = json.loads(image.band_view)
        reader = ImageData(image_server, ModelVisitor.object_to_json(image))
        return reader, bands

    @staticmethod
    def _to_bgr(bands, height, width):
        """Stack three 2-D bands into an integer array with the band order reversed.

        Band 0 ends up in channel 2 and band 2 in channel 0 (the original
        comment attributes this to OpenCV's channel ordering).
        """
        pixel_array = numpy.zeros((height, width, 3), dtype=int)
        for idx in range(3):
            pixel_array[:, :, 2 - idx] = bands[idx]
        return pixel_array

    def _render_single(self, image, servers, extent, height, width):
        """Fetch the pixels of one image and return them in reversed band order."""
        reader, bands = self._open_image(image, servers)
        data = reader.get_data(extent, bands, height, width)
        return self._to_bgr([data[:, :, idx] for idx in range(3)], height, width)

    def _render_mosaic(self, images, servers, extent, height, width):
        """Fetch several images concurrently and composite them into one picture.

        One worker thread is started per image. Results are merged in the
        order of ``images``: wherever the first band of a result differs from
        the no-data value, that result replaces what was merged before.
        """
        merged = [numpy.full((height, width), self._NODATA, dtype=int)
                  for _ in range(3)]

        workers = []
        for image in images:
            reader, bands = self._open_image(image, servers)
            worker: MyThread = MyThread(
                reader.get_data, args=(extent, bands, height, width))
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()
            data = worker.get_result()
            # The first band alone decides which pixels carry no data.
            nodata = data[:, :, 0] == self._NODATA
            for idx in range(3):
                merged[idx] = numpy.where(nodata, merged[idx], data[:, :, idx])

        return self._to_bgr(merged, height, width)

    def determin_intersect(self, extent1, extent2):
        """Return True when the two extents overlap or touch.

        Both extents are indexed 0..3; indices 0/2 are compared with each
        other and so are indices 1/3.
        """
        disjoint = (
            extent2[2] < extent1[0]
            or extent2[0] > extent1[2]
            or extent2[1] > extent1[3]
            or extent2[3] < extent1[1]
        )
        return not disjoint

    api_doc = {
        "tags": ["影像接口"],
        "parameters": [
            {"name": "image_guids",
             "in": "query",
             "type": "string"},
            {"name": "get_extent",
             "in": "query",
             "type": "boolean",
             "enum":[True,False]},
            {"name": "bbox",
             "in": "query",
             "type": "string"},
            {"name": "width",
             "in": "query",
             "type": "string"},
            {"name": "height",
             "in": "query",
             "type": "string"},
            {"name": "format",
             "in": "query",
             "type": "string"},
            {"name": "quality",
             "in": "query",
             "type": "string"}
        ],
        "responses": {
            200: {
                "schema": {
                    "properties": {
                    }
                }
            }
        }
    }