Pytorch Api 实现直方图均衡化

pytorch-api-实现直方图均衡化

完整代码

import torch
import numpy as np
from torch import nn
from PIL import Image
import torch.nn.functional as F
import os

def histogram_equalization(image_tensor):

    """对图像进行直方图均衡化
    Args:
    image_tensor (tensor): 输入的tensor图像, 是单通道的灰度图像

    Returns:
        tensor: 输出的tensor图像

    """

    # 将图像转换为灰度图像
    # image_tensor = 0.299*image[:,0,:,:] + 0.587*image[:,1,:,:] + 0.114*image[:,2,:,:]
    # 计算灰度图像的直方图
    hist = torch.histc(image_tensor, bins=256, min=0, max=255)
    # 计算灰度图像的累积分布函数
    cdf = hist.cumsum(dim=0)
    # 归一化累积分布函数
    cdf = (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min())
    # 对灰度图像进行直方图均衡化
    equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float()
    # 将均衡化后的图像转换为 RGB 图像
    # equalized_image = torch.stack([equalized_image, equalized_image, equalized_image], dim=1)
    return equalized_image.reshape(image_tensor.shape)

def main():
    # 把工作目录切换到当前文件夹所在目录
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    # 打印当前工作目录
    print(os.getcwd()) 
    # 读入一张图片,并转换为灰度图
    # im = Image.open('./test2.jpg').convert('L')
    im = Image.open('./test.png').convert('L')
    # 将图片数据转换为矩阵
    im = np.array(im, dtype='float32')
    # 将图片矩阵转换为pytorch tensor,并适配卷积输入的要求
    im = torch.from_numpy(im)
    print("im.shape: ", im.shape)

    img_tensor = histogram_equalization(im.squeeze().detach())

    print("img_tensor after equalization shape:", img_tensor.shape)
    print("img_tensor after equalization : ", img_tensor)

    # edge_detect 阈值, 将小于阈值的元素设置为0
    img_tensor[img_tensor < 200] = 0

    img_tensor = img_tensor.squeeze().detach().numpy()
    print("img_tensor to numpy : ", img_tensor, type(img_tensor))

    # 将array数据转换为image
    im_image = Image.fromarray(img_tensor)
    
    # image数据转换为灰度模式
    im_image_L = im_image.convert('L')

    # 将Image数据转换为numpy array
    im_L_numpy = np.array(im_image, dtype='uint8')
    print("im_L_numpy.shape: ", im_L_numpy.shape)
    print("im_L_numpy: ", im_L_numpy)

    # 保存图片
    im_image_L.save('res.png', quality=95)
 
if __name__ == "__main__":
    main()

histogram_equalization 函数解释:

  • hist = torch.histc(image_tensor, bins=256, min=0, max=255)     1. torch.histc()函数的作用是计算张量的直方图     2. bins: 直方图的柱数,即区间的个数,bins=256表示将0-255的像素值分为256个区间,        每个区间的宽度为1,即0-1, 1-2, 2-3, …, 254-255, 255-256,共256个区间     3. min: 统计的最小值,max: 统计的最大值     5. hist: 直方图的统计结果,是一个一维的tensor,长度为bins,即256,每个元素表示该区间的像素值的个数

  • cdf = hist.cumsum(dim=0)     1. hist.cumsum(dim=0) 表示对hist的每个元素进行累加,dim=0表示按照第0维进行累加,        例如:即将hist的第0个元素与第1个元素相加,结果作为第1个元素,再将第1个元素与第2个元素相加,结果作为第2个元素,以此类推。     2. cdf: 累积分布函数,是一个一维的tensor,长度为bins,即256,第x个元素表示区间[0,x]的像素值的个数.

  • cdf = (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min())     1. 将cdf的值归一化到0-255之间     2. cdf.min() 表示cdf中的最小值, cdf.max() 表示cdf中的最大值, (cdf - cdf.min()) 表示cdf中的每个元素减去cdf中的最小值, (cdf.max() - cdf.min()) 表示cdf中的最大值减去cdf中的最小值     3. (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min()) 表示将cdf中的每个元素减去cdf中的最小值, 再乘以255, 再除以(cdf.max() - cdf.min())

  • equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float()     1. cdf.to(torch.int64) 将cdf的数据类型转换为torch.int64     2. image_tensor.to(torch.int64).reshape(-1) 将image_tensor的数据类型转换为   torch.int64,并将其转换为一维的tensor     3. torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)) 从cdf中按照image_tensor的值进行索引     4. equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float() 将索引得到的值转换为float类型    

  • equalized_image.reshape(image_tensor.shape)     1. 将equalized_image的形状转换为image_tensor的形状。

torch.gather 这个函数是什么意思?

torch.gather() 是 PyTorch 中的一个函数,用于在一个输入 Tensor 中按照指定的维度和索引聚合数据。具体来说,torch.gather() 函数的功能是从输入 Tensor 中按照指定的维度和索引收集数据,并返回一个新的 Tensor。

torch.gather() 函数的基本用法如下:

torch.gather(input, dim, index, out=None)

其中,input 表示输入的 Tensor,dim 表示指定的维度,index 表示索引,out 表示输出的 Tensor,如果不指定则会创建一个新的 Tensor。

下面是一个最简单的 torch.gather() 的例子,假设我们有一个 2x3 的 Tensor:

import torch

x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

我们希望从第一个维度(即行)中选择第 0 行和第 1 行的元素,可以使用以下代码:

indices = torch.tensor([
    [0, 2, 1],
    [1, 0, 2]
])

result = torch.gather(x, 1, indices)

其中,indices 表示要选择的索引,1 表示要在第一个维度上进行聚合。运行上述代码后,result 的值为:

tensor([[1, 3, 2],
        [5, 4, 6]])

可以看到,result 的第一行表示从 x 的第一行中选择了索引为 0、2、1 的元素,第二行表示从 x 的第二行中选择了索引为 1、0、2 的元素。

需要注意的是,torch.gather() 函数的输入和输出 Tensor 的形状必须满足一定的条件,具体可以参考 PyTorch 官方文档。

实验效果

输入:test.png

输出:res.png

阈值200后的输出:img_tensor[img_tensor < 200] = 0