Pytorch Api 实现直方图均衡化
pytorch-api-实现直方图均衡化
完整代码
import torch
import numpy as np
from torch import nn
from PIL import Image
import torch.nn.functional as F
import os
def histogram_equalization(image_tensor):
"""对图像进行直方图均衡化
Args:
image_tensor (tensor): 输入的tensor图像, 是单通道的灰度图像
Returns:
tensor: 输出的tensor图像
"""
# 将图像转换为灰度图像
# image_tensor = 0.299*image[:,0,:,:] + 0.587*image[:,1,:,:] + 0.114*image[:,2,:,:]
# 计算灰度图像的直方图
hist = torch.histc(image_tensor, bins=256, min=0, max=255)
# 计算灰度图像的累积分布函数
cdf = hist.cumsum(dim=0)
# 归一化累积分布函数
cdf = (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min())
# 对灰度图像进行直方图均衡化
equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float()
# 将均衡化后的图像转换为 RGB 图像
# equalized_image = torch.stack([equalized_image, equalized_image, equalized_image], dim=1)
return equalized_image.reshape(image_tensor.shape)
def main():
# 把工作目录切换到当前文件夹所在目录
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# 打印当前工作目录
print(os.getcwd())
# 读入一张图片,并转换为灰度图
# im = Image.open('./test2.jpg').convert('L')
im = Image.open('./test.png').convert('L')
# 将图片数据转换为矩阵
im = np.array(im, dtype='float32')
# 将图片矩阵转换为pytorch tensor,并适配卷积输入的要求
im = torch.from_numpy(im)
print("im.shape: ", im.shape)
img_tensor = histogram_equalization(im.squeeze().detach())
print("img_tensor after equalization shape:", img_tensor.shape)
print("img_tensor after equalization : ", img_tensor)
# edge_detect 阈值, 将小于阈值的元素设置为0
img_tensor[img_tensor < 200] = 0
img_tensor = img_tensor.squeeze().detach().numpy()
print("img_tensor to numpy : ", img_tensor, type(img_tensor))
# 将array数据转换为image
im_image = Image.fromarray(img_tensor)
# image数据转换为灰度模式
im_image_L = im_image.convert('L')
# 将Image数据转换为numpy array
im_L_numpy = np.array(im_image, dtype='uint8')
print("im_L_numpy.shape: ", im_L_numpy.shape)
print("im_L_numpy: ", im_L_numpy)
# 保存图片
im_image_L.save('res.png', quality=95)
if __name__ == "__main__":
main()
histogram_equalization
函数解释:
hist = torch.histc(image_tensor, bins=256, min=0, max=255)
1. torch.histc()函数的作用是计算张量的直方图 2. bins: 直方图的柱数,即区间的个数,bins=256表示将0-255的像素值分为256个区间, 每个区间的宽度为1,即0-1, 1-2, 2-3, …, 254-255, 255-256,共256个区间 3. min: 统计的最小值,max: 统计的最大值 5. hist: 直方图的统计结果,是一个一维的tensor,长度为bins,即256,每个元素表示该区间的像素值的个数cdf = hist.cumsum(dim=0)
1. hist.cumsum(dim=0) 表示对hist的每个元素进行累加,dim=0表示按照第0维进行累加, 例如:即将hist的第0个元素与第1个元素相加,结果作为第1个元素,再将第1个元素与第2个元素相加,结果作为第2个元素,以此类推。 2. cdf: 累积分布函数,是一个一维的tensor,长度为bins,即256,第x个元素表示区间[0,x]的像素值的个数.cdf = (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min())
1. 将cdf的值归一化到0-255之间 2. cdf.min() 表示cdf中的最小值, cdf.max() 表示cdf中的最大值, (cdf - cdf.min()) 表示cdf中的每个元素减去cdf中的最小值, (cdf.max() - cdf.min()) 表示cdf中的最大值减去cdf中的最小值 3. (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min()) 表示将cdf中的每个元素减去cdf中的最小值, 再乘以255, 再除以(cdf.max() - cdf.min())equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float()
1. cdf.to(torch.int64) 将cdf的数据类型转换为torch.int64 2. image_tensor.to(torch.int64).reshape(-1) 将image_tensor的数据类型转换为 torch.int64,并将其转换为一维的tensor 3. torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)) 从cdf中按照image_tensor的值进行索引 4. equalized_image = torch.gather(cdf.to(torch.int64), 0, image_tensor.to(torch.int64).reshape(-1)).float() 将索引得到的值转换为float类型equalized_image.reshape(image_tensor.shape)
1. 将equalized_image的形状转换为image_tensor的形状。
torch.gather 这个函数是什么意思?
torch.gather()
是 PyTorch 中的一个函数,用于在一个输入 Tensor 中按照指定的维度和索引聚合数据。具体来说,torch.gather()
函数的功能是从输入 Tensor 中按照指定的维度和索引收集数据,并返回一个新的 Tensor。
torch.gather()
函数的基本用法如下:
torch.gather(input, dim, index, out=None)
其中,input
表示输入的 Tensor,dim
表示指定的维度,index
表示索引,out
表示输出的 Tensor,如果不指定则会创建一个新的 Tensor。
下面是一个最简单的 torch.gather()
的例子,假设我们有一个 2x3
的 Tensor:
import torch
x = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
我们希望从第一个维度(即行)中选择第 0 行和第 1 行的元素,可以使用以下代码:
indices = torch.tensor([
[0, 2, 1],
[1, 0, 2]
])
result = torch.gather(x, 1, indices)
其中,indices
表示要选择的索引,1
表示要在第一个维度上进行聚合。运行上述代码后,result
的值为:
tensor([[1, 3, 2],
[5, 4, 6]])
可以看到,result
的第一行表示从 x
的第一行中选择了索引为 0、2、1 的元素,第二行表示从 x
的第二行中选择了索引为 1、0、2 的元素。
需要注意的是,torch.gather()
函数的输入和输出 Tensor 的形状必须满足一定的条件,具体可以参考 PyTorch 官方文档。
实验效果
输入:test.png
输出:res.png
阈值200后的输出:img_tensor[img_tensor < 200] = 0