import json
import os
import random
import warnings
from collections import OrderedDict
from typing import Dict, List, Callable, Union, Optional
import numpy as np
import tqdm
from .meta import Meta, Col, Voc
from .vocab import Vocab
from .vocabs import Vocabs
class UniDep:
VER = Meta.VER
def __init__(self, store_dir, silent=False):
    """
    Load a stored depot from ``store_dir``.

    The directory is read by ``Meta``, by ``np.load`` (``data.npy``) and by
    ``Vocab.load`` for every vocab listed in the meta.

    :param store_dir: depot directory; ``~`` is expanded.
    :param silent: when True, messages routed through ``self.print`` and the
        caching progress bar are suppressed.

    NOTE(review): if ``data.npy`` does not unwrap via ``.item()``, the error is
    printed and the constructor returns early, leaving the instance only
    partially initialised.
    """
    self.store_dir = os.path.expanduser(store_dir)
    self.meta = Meta(self.store_dir)
    self.silent = silent

    # in-memory sample cache, populated by start_caching()
    self.cached = False
    self.cached_samples = []

    self.data_path = os.path.join(self.store_dir, 'data.npy')
    loaded = np.load(self.data_path, allow_pickle=True)
    self.data = loaded
    try:
        # the file is expected to wrap a {column name: values} dict
        # noinspection PyTypeChecker
        self.data: dict = loaded.item()
    except Exception as err:
        print(err)
        return

    self.cols = self.meta.cols  # type: Dict[str, Col]
    self.vocs = self.meta.vocs  # type: Dict[str, Voc]
    self.id_col = self.meta.id_col
    self.id_voc = self.cols[self.id_col].voc

    # size first taken from the id vocab, then corrected to the real data length
    self._indexes = []
    self.sample_size = -1
    self.set_sample_size(self.id_voc.size)
    self._sample_size = len(self.data[self.id_col])
    if self._sample_size != self.sample_size:
        self.set_sample_size(self._sample_size)

    # load every vocab, then attach each loaded vocab to its config entry
    self.vocabs = Vocabs()
    for voc_name in self.vocs:
        self.vocabs.append(Vocab(name=voc_name).load(self.store_dir))
    for voc_name, voc in self.vocs.items():
        voc.vocab = self.vocabs[voc_name]

    self.id2index = self.vocabs[self.id_voc.name].o2i
    self.unions = OrderedDict()  # type: Dict[str, List[UniDep]]
    self._deep_union = False
def set_sample_size(self, size):
    """
    Set how many samples are visible and reset the index view to 0..size-1.

    The first call (while ``sample_size`` is still the -1 placeholder) reports
    the load; every later call reports the modification.
    """
    first_time = self.sample_size <= -1
    self.sample_size = size
    self._indexes = [*range(size)]
    if first_time:
        self.print(f'loaded {self.sample_size} samples from {self.store_dir}')
    else:
        self.print('modify sample_size to', self.sample_size)
def print(self, *args, **kwargs):
    """
    Forward to the builtin ``print`` unless the depot is silent.
    """
    if not self.silent:
        print(*args, **kwargs)
def pack_sample(self, index) -> dict:
    """
    Build the sample stored at raw data ``index``.

    ``index`` is a position in the column arrays of ``self.data``, not a
    position in the (possibly truncated) view used by ``__getitem__``.

    For shallow unions, the value of each join column is used as a raw row
    index into every depot union-ed on that column, and the foreign sample is
    merged in. Deep unions need no lookup, because ``union`` already copied the
    foreign columns into ``self.data``.

    :param index: raw row index into ``self.data``.
    :return: dict of column name -> value. When caching is on, the cached dict
        object itself is returned.
    """
    if self.cached:
        return self.cached_samples[index]

    sample = {col_name: values[index] for col_name, values in self.data.items()}
    if self._deep_union:
        return sample

    for col_name, depots in self.unions.items():
        foreign_index = sample[col_name]
        for depot in depots:
            # Raw-index lookup, matching how union() gathers deep-union rows.
            # Going through depot[...] would apply the other depot's own
            # _indexes view and pick the wrong row once that view is changed.
            sample.update(depot.pack_sample(foreign_index))
    return sample
def get_sample_by_id(self, obj_id):
    """
    Return the sample for ``obj_id``.

    The object id is resolved to a raw index through the id vocab's ``o2i``
    mapping.
    """
    raw_index = self.id2index[obj_id]
    return self.pack_sample(raw_index)
def start_caching(self):
    """
    Materialise every visible sample into memory; a no-op if already cached.

    Samples are produced by iterating the depot, then stored at the slot given
    by their id-column value. Presumes id values are raw row indexes below
    ``_sample_size``, since ``pack_sample`` reads the cache by raw index.
    """
    if not self.cached:
        buffer = [None] * self._sample_size
        self.cached_samples = buffer
        for sample in tqdm.tqdm(self, disable=self.silent):
            buffer[sample[self.id_col]] = sample
        self.cached = True
def __getitem__(self, index):
    """Return the sample at position ``index`` of the current index view."""
    return self.pack_sample(self._indexes[index])
def __iter__(self):
    """Yield samples in index-view order, positions 0 .. len(self) - 1."""
    yield from (self[position] for position in range(len(self)))
def __len__(self):
    """Number of currently visible samples (``sample_size``)."""
    size = self.sample_size
    return size
def __str__(self):
    """
    Human-readable summary of the depot.

    Includes the version, store dir, sample size and id column, then one line
    per column with its vocab name and size, plus the max length when set.
    """
    version = self.meta.parse_version(self.meta.version)
    parts = [f"""
UniDep ({version}): {self.store_dir}
Sample Size: {self.sample_size}
Id Column: {self.id_col}
Columns:\n"""]
    for col_name, col in self.cols.items():  # type: str, Col
        entry = f' \t{col_name}, vocab {col.voc.name} (size {col.voc.size})'
        if col.max_length:
            entry += f', max length {col.max_length}'
        parts.append(entry + '\n')
    return ''.join(parts)
def __repr__(self):
    """Identical to ``str(self)``."""
    return self.__str__()
"""
Advanced methods, including union, filter
"""
@staticmethod
def _merge(d1: dict, d2: dict) -> dict:
    """
    Return a shallow copy of ``d1`` overlaid with ``d2``.

    ``d2`` wins on duplicate keys. Neither input is modified.
    """
    merged = d1.copy()
    for key, value in d2.items():
        merged[key] = value
    return merged
@classmethod
def _merge_cols(cls, c1: Dict[str, Col], c2: Dict[str, Col]) -> Dict[str, Col]:
    """
    Merge two column configs.

    :raises ValueError: for the first name in ``c2`` that also exists in
        ``c1`` with a config that does not compare equal.
    """
    clashes = (name for name, col in c2.items() if name in c1 and c1[name] != col)
    for name in clashes:
        raise ValueError(f'col {name} config conflict')
    return cls._merge(c1, c2)
@classmethod
def _merge_vocs(cls, v1: Dict[str, Voc], v2: Dict[str, Voc]) -> Dict[str, Voc]:
    """
    Merge two vocab configs into a new dict.

    Names only in ``v2`` are added as-is. For a name present in both, the
    entries must compare equal and are then combined with ``Voc.merge``.

    :raises ValueError: if a shared name's entries do not compare equal.
    """
    result = v1.copy()
    for name, voc in v2.items():
        if name not in v1:
            result[name] = voc
            continue
        existing = v1[name]
        if existing != voc:
            raise ValueError(f'vocab {name} config conflict')
        result[name] = existing.merge(voc)
    return result
def deep_union(self, value):
    """
    Select the union mode for later ``union`` calls.

    True copies foreign columns into ``self.data``; False leaves them to be
    resolved per sample in ``pack_sample``.

    :raises ValueError: when the mode would change after a union was made.
    """
    mode_changes = self._deep_union != value
    if mode_changes and self.unions:
        raise ValueError('deep_union can not be changed after union-ed')
    self._deep_union = value
def union(self, *depots: 'UniDep'):
    """
    Attach other depots, joined on each depot's id column.

    Each depot's ``id_col`` must already be a column of this depot. Column and
    vocab configs are merged into this depot and its meta. In deep-union mode,
    the join column's values are used as row indexes into the other depot's
    data, and those columns are copied into ``self.data`` as object arrays.
    Otherwise the foreign columns are resolved later by ``pack_sample``.

    :return: self, for chaining.
    :raises ValueError: if a depot's id column is missing here, or if
        column/vocab configs conflict.
    """
    for depot in depots:
        join_col = depot.id_col
        if join_col not in self.cols:
            raise ValueError('current depot has no column named {}'.format(join_col))
        self.unions.setdefault(join_col, []).append(depot)

        self.cols = self._merge_cols(self.cols, depot.cols)
        self.vocs = self._merge_vocs(self.vocs, depot.vocs)
        self.meta.cols = self.cols
        self.meta.vocs = self.vocs

        if self._deep_union:
            # gather every foreign column row by row before writing anything
            join_values = self.data[join_col]
            gathered = {name: [] for name in depot.cols}
            for row_index in join_values:
                for name, bucket in gathered.items():
                    bucket.append(depot.data[name][row_index])
            for name, bucket in gathered.items():
                self.data[name] = np.array(bucket, dtype=object)
    return self
def inject(self, depot: 'UniDep', col_names: Union[list, dict]):
    """
    Copy columns from ``depot`` into this depot, joined on ``depot.id_col``.

    For every row of this depot, the value of ``depot.id_col`` is used as a
    row index into ``depot``'s column data. Each copied column is registered
    through ``self.set_col`` with the source column's vocab.

    :param depot: source depot; its ``id_col`` must be a column of this depot.
    :param col_names: a list/tuple of column names to copy under the same
        name, or a dict mapping ``{name in depot: new name here}``.
    :raises ValueError: if this depot has no column named ``depot.id_col``.
    """
    if isinstance(col_names, (list, tuple)):
        col_names = {col_name: col_name for col_name in col_names}
    if depot.id_col not in self.cols:
        raise ValueError(f'current depot has no column named {depot.id_col}')

    join_values = self.data[depot.id_col]
    # Read everything by the SOURCE name first (the previous code wrongly used
    # the target name to index depot), then register under the TARGET name.
    pending = []
    for src_name, tgt_name in col_names.items():
        src_values = depot.data[src_name]
        pending.append((
            tgt_name,
            [src_values[index] for index in join_values],
            depot.cols[src_name].voc.vocab,
        ))
    for tgt_name, values, vocab in pending:
        self.set_col(
            name=tgt_name,
            values=values,
            vocab=vocab,
        )
def rename_col(self, old_name: str, new_name: str):
    """
    Rename a column in the column config and the data store.

    The rename is also applied to the id column (here and in the meta) and to
    the key of any union joined on this column.

    :raises ValueError: if ``old_name`` is unknown or ``new_name`` already exists.
    """
    if old_name not in self.cols:
        raise ValueError(f'column {old_name} not found')
    if new_name in self.cols:
        raise ValueError(f'column {new_name} already exists')
    # value comparison: `is` on strings depends on interning and can miss
    if old_name == self.id_col:
        self.meta.id_col = self.id_col = new_name
    self.cols[new_name] = self.cols.pop(old_name)
    self.data[new_name] = self.data.pop(old_name)
    # pack_sample reads each union key from the sample by column name, so a
    # stale key would raise KeyError after the rename
    if old_name in self.unions:
        self.unions[new_name] = self.unions.pop(old_name)
def rename_vocab(self, old_name: str,
统一文本数据预处理工具.zip
版权申诉
56 浏览量
2024-03-02
21:59:28
上传
评论
收藏 22KB ZIP 举报
博士僧小星
- 粉丝: 1936
- 资源: 5894
最新资源
- 基于STM32的毕业设计项目可以涵盖多个领域和应用,以下是一个典型的基于STM32的毕业设计项目框架,并结合参考文章中的相关数字
- 对于端午节代码资源,你可以考虑以下几个方向: ### 1. 端午节主题的小游戏 你可以创建一个端午节主题的小游戏,比如"捞粽
- 如果你是在寻找编程相关的节日主题代码资源,我可以为你提供一些常见的做法和示例 以下是一些可能会有帮助的方法: ### 1. 制
- NX二次开发uc6496 函数介绍
- 在 MATLAB 中创建 GUI(图形用户界面)并进行仿真是一种常见的做法,特别是在需要交互式地探索数据或者模拟系统行为时 下面
- Vue.js 是一个流行的 JavaScript 框架,用于构建用户界面和单页面应用 它易于学习、灵活且功能强大,以下是一些 V
- NX二次开发uc6494 函数介绍
- 电力技术方案模板第二版
- 6.5图片/////////
- NX二次开发uc6483 函数介绍
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈