import os
import platform
import random
from typing import List
import numpy as np
import torch
from torch.utils.data import Dataset
from denoise.data.pointcloud import PointCloudBase, Cache, PointCloudPatch, PointCloudFeature
class PointCloudPatchDataset(Dataset):
def __init__(self,
             point_cloud_name_path: str,
             indir: str,
             target_features: List[str],
             points_per_patch: int,
             use_pca: bool,
             radius: List[float],
             center: str,
             point_tuple: int,
             cache_capacity: int,
             seed: int,
             patch_sampling: str):
    """Store the configuration and pre-convert every listed point cloud.

    Args:
        point_cloud_name_path: Text file read by ``get_point_cloud_list``.
        indir: Directory that holds the per-cloud data files.
        target_features: Feature names forwarded to ``PointCloudFeature``.
        points_per_patch: Stored as ``self.points_per_patch``.
        use_pca: Forwarded to ``PointCloudFeature``.
        radius: Radii handed to every ``PointCloudBase`` built by ``save``.
        center: Stored as ``self.center``.
        point_tuple: Stored as ``self.point_tuple``.
        cache_capacity: Capacity passed to ``Cache``.
        seed: Seed of the dataset's own ``np.random.RandomState``.
        patch_sampling: Stored as ``self.patch_sampling``.
    """
    # Plain configuration values, kept exactly as passed in.
    self.indir = indir
    self.point_cloud_name_path: str = point_cloud_name_path
    self.radius: List[float] = radius
    self.point_tuple: int = point_tuple
    self.points_per_patch: int = points_per_patch
    self.center: str = center
    self.patch_sampling = patch_sampling

    self.point_cloud_names: List[str] = self.get_point_cloud_list()
    # Basic per-cloud information stays resident in memory; the actual data
    # is read through the cache.
    self.point_cloud_base_info: List[PointCloudBase] = []
    # Features required from the point clouds.
    self.point_cloud_features: PointCloudFeature = PointCloudFeature(target_features, use_pca)

    # Convert the data files to .npy and record the basic info of each cloud.
    # save() relies on indir, radius, the name list and the feature flags,
    # so it has to run after all of them are set.
    self.save()
    self.pred_dim: int = self.point_cloud_features.dim

    self.cache: Cache = Cache(cache_capacity)
    self.rng: np.random.RandomState = np.random.RandomState(seed)
def get_point_cloud_list(self):
    """Read the point cloud names listed in ``self.point_cloud_name_path``.

    Each line of the file is stripped of surrounding whitespace; lines that
    are empty after stripping are dropped.

    Names are used verbatim on every operating system. No platform needs a
    name suffix, so there is nothing to reject on systems other than Linux
    and Windows (for example macOS).

    Returns:
        The non-empty names, in file order.
    """
    with open(self.point_cloud_name_path) as f:
        stripped_lines = (line.strip() for line in f)
        return [name for name in stripped_lines if name]
def outliers_generator_key_points(self, points, outliers, point_cloud_name):
    """Write a class-balanced ``<point_cloud_name>.pidx`` file into ``self.indir``.

    Point indices whose ``outliers`` entry equals 0 form one group and those
    whose entry equals 1 form the other. The same number of indices is drawn
    from each group (the size of the smaller group), the two draws are
    concatenated, shuffled and saved with ``np.savetxt``.

    Sampling and shuffling use the module-level ``random`` generator.

    Args:
        points: Points of the cloud; only ``len(points)`` is used.
        outliers: Per-point labels compared against 0 and 1.
        point_cloud_name: Base name of the output file.
    """
    all_indices = np.arange(len(points))
    regular_pool = list(all_indices[outliers == 0])
    outlier_pool = list(all_indices[outliers == 1])

    # Draw equally many regular points and outliers.
    ratio = 1.0
    per_group = int(min(len(regular_pool), len(outlier_pool)) * ratio)
    key_points = random.sample(regular_pool, per_group)
    key_points = key_points + random.sample(outlier_pool, per_group)
    random.shuffle(key_points)

    pidx_path = os.path.join(self.indir, point_cloud_name + ".pidx")
    np.savetxt(pidx_path, key_points)
def save_to_npy(self, point_cloud_name, file_suffix, file_type):
    """Load one text data file of a cloud, caching it as ``<file>.npy``.

    Args:
        point_cloud_name: Base name of the cloud inside ``self.indir``.
        file_suffix: Suffix appended to the name (callers pass it with the
            leading dot, e.g. ``".xyz"``).
        file_type: dtype the parsed text is cast to.

    Returns:
        ``None`` when the text file does not exist. Otherwise the array,
        taken from the ``.npy`` cache when that is allowed, else parsed from
        the text file (which also rewrites the cache).
    """
    text_path = os.path.join(self.indir, point_cloud_name + file_suffix)
    if not os.path.exists(text_path):
        return None
    npy_path = text_path + ".npy"

    # Remember: if the configuration for reading point clouds changes, the
    # affected suffixes have to be listed here so they are always re-parsed.
    always_reparse = []
    # The whole-process pipeline must regenerate everything on every run,
    # otherwise results from an earlier run would be picked up.
    in_whole_process = "whole_process" in self.indir

    cache_allowed = file_suffix not in always_reparse and not in_whole_process
    if cache_allowed and os.path.exists(npy_path):
        return np.load(npy_path)

    parsed = np.loadtxt(text_path).astype(file_type)
    np.save(npy_path, parsed)
    return parsed
def save(self):
    """Convert every listed cloud to ``.npy`` and record its base information.

    For each name in ``self.point_cloud_names`` the ``.xyz`` file is
    required. The ``.normals``, ``.outliers``, ``.curv``, ``.clean_xyz`` and
    ``.pidx`` files are converted only when the matching ``include_*`` flag
    of ``self.point_cloud_features`` is set.

    The patch count of a cloud is:
      * the number of indices in its ``.pidx`` file when key points are
        requested;
      * the number of generated key points when only outlier labels are
        requested (a ``.pidx`` file is then written by
        ``outliers_generator_key_points``);
      * the number of points otherwise.

    One ``PointCloudBase`` per cloud is appended to
    ``self.point_cloud_base_info``.

    Raises:
        ValueError: if the ``.xyz`` file is missing, if key points are
            requested but the ``.pidx`` file is missing, or if key points
            must be generated but the ``.outliers`` file is missing.
    """
    features = self.point_cloud_features
    for index, point_cloud_name in enumerate(self.point_cloud_names):
        print(f"getting information from {point_cloud_name}...")
        points = self.save_to_npy(point_cloud_name, ".xyz", "float32")
        if points is None:
            raise ValueError(f"{self.indir}")

        outliers = None
        if features.include_normal:
            self.save_to_npy(point_cloud_name, ".normals", "float32")
        if features.include_outliers:
            outliers = self.save_to_npy(point_cloud_name, ".outliers", "float32")
        if features.include_curvatures:
            self.save_to_npy(point_cloud_name, ".curv", "float32")
        if features.include_clean_points:
            self.save_to_npy(point_cloud_name, ".clean_xyz", "float32")

        if features.include_key_points:
            key_points_indices = self.save_to_npy(point_cloud_name, ".pidx", "int32")
            if key_points_indices is None:
                # Fail with a clear message instead of len(None).
                raise ValueError(
                    f"key points requested but {point_cloud_name}.pidx "
                    f"is missing in {self.indir}"
                )
            patch_count = len(key_points_indices)
        elif features.include_outliers:
            # Outlier labels are available but no key points were requested:
            # derive a balanced key-point file from the labels.
            if outliers is None:
                # Without labels the generator would silently produce an
                # empty key-point file, i.e. a cloud with zero patches.
                raise ValueError(
                    f"outliers requested but {point_cloud_name}.outliers "
                    f"is missing in {self.indir}"
                )
            self.outliers_generator_key_points(points, outliers, point_cloud_name)
            key_points_indices = self.save_to_npy(point_cloud_name, ".pidx", "int32")
            patch_count = len(key_points_indices)
        else:
            patch_count = points.shape[0]

        base_info = PointCloudBase(self.indir, index, point_cloud_name, patch_count, points, self.radius)
        self.point_cloud_base_info.append(base_info)

    # Once outlier labels are in use, every cloud has a .pidx file, so switch
    # the dataset over to key-point based patch centers.
    # NOTE(review): the flattened source does not show whether this ran per
    # cloud or once after the loop; once-after-the-loop is the only placement
    # under which every cloud gets its own .pidx generated -- confirm.
    if features.include_outliers:
        features.include_key_points = True
    print(f"data load finish... all point cloud are {len(self.point_cloud_names)}")
def __len__(self):
    """Return the total number of patches over all registered point clouds."""
    return sum(info.patch_count for info in self.point_cloud_base_info)
def get_point_cloud_and_patch_by_dataset_index(self, index):
    """Translate a dataset-wide index into ``(cloud position, patch position)``.

    The clouds in ``self.point_cloud_base_info`` are laid out back to back,
    each occupying ``patch_count`` consecutive dataset indices.

    Args:
        index: Dataset-wide patch index.

    Returns:
        A tuple of the cloud's position in ``self.point_cloud_base_info`` and
        the patch's position inside that cloud.

    Raises:
        ValueError: if ``index`` falls outside every cloud's range (this
            includes negative indices).
    """
    offset = 0
    for cloud_pos, info in enumerate(self.point_cloud_base_info):
        local = index - offset
        if 0 <= local < info.patch_count:
            return cloud_pos, local
        offset += info.patch_count
    raise ValueError(f"不存在这样的一个index坐标!{index}")
def __getitem__(self, index):
# 拿到点云id和点云片id
id, patch_id = self.get_point_cloud_and_patch_by_dataset_index(index)
point_cloud = self.cache.get(self.point_cloud_base_info[id])
if self.point_cloud_features.include_key_points:
# 如果存在关键点,那么取到的是关键点的索引
center = point_cloud.key_points[patch_id]
else:
center = patch_id
patch_radii = point_cloud.base_info.patch_radii
final_point_cloud_patch_list = []
for k, patch_radius in enumerate(patch_radii):
center_point = point_cloud.points[center]
point_cloud_patch = self.get_point_cloud_patch_by_center(point_cloud.points,
point_cloud.kdtree,
center_point,
patch_radius)
final_point_cloud_patch_list.append(point_cloud_patch)
final_point_cloud_patch_features_list = []
if self.point_cloud_features.include_normal:
patch_normal = point_cloud.normals[center]
final_point_cloud_patch_features_list.append(patch_normal)
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
基于深度学习的三维点云去噪.zip (39个子文件)
3D-point-cloud-denoising-main
denoise
__init__.py 0B
eval.py 3KB
data
datasampler.py 3KB
pointcloud.py 8KB
dataset.py 16KB
data_preprocess.py 4KB
noise_removal
loss.py 2KB
model
pcpnet.py 11KB
respcpnet.py 9KB
train_config.py 6KB
eval_config.py 3KB
utils
net_util.py 1KB
file_util.py 621B
train_util.py 3KB
seed_util.py 215B
point_cloud_util.py 1KB
device_util.py 1KB
outliers_removal
loss.py 3KB
model
pcpnet.py 11KB
respcpnet.py 9KB
train_config.py 6KB
eval_config.py 3KB
process
preprocess.py 7KB
process.py 10KB
common
core.py 559B
outliers_removal_config.json 261B
config.json 42B
noise_removal_config.json 246B
config.py 3KB
whole_process.py 10KB
train.py 3KB
addnoise
__init__.py 4KB
util.py 224B
environments.py 1KB
eval_indictor
calculate_fscore.py 683B
__init__.py 0B
util.py 529B
distance.py 2KB
score.py 5KB
共 39 条
- 1
资源评论
博士僧小星
- 粉丝: 1947
- 资源: 5905
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功