import numpy as np
import torch
import torch.nn.functional as F
from torch_struct import SemiMarkovCRF
import argparse
from sklearn.mixture import GaussianMixture
from scipy.optimize import linear_sum_assignment
class SemiMarkovModule(torch.nn.Module):
@classmethod
def add_args(cls, parser):
parser.add_argument('--sm_allow_self_transitions', action='store_true')
parser.add_argument('--sm_lr', type=float, default=1e-1)
parser.add_argument('--sm_supervised_state_smoothing', type=float, default=1e-2)
parser.add_argument('--sm_supervised_length_smoothing', type=float, default=1e-1)
parser.add_argument('--sm_supervised_cov_smoothing', type=float, default=0.)
parser.add_argument('--sm_cov_factor', type=float, default=1.)
def __init__(self, args, n_dims):
super(SemiMarkovModule, self).__init__()
self.args = args
self.n_classes = args.sm_n_classes
self.input_feature_dim = n_dims
self.feature_dim = n_dims
self.max_k = args.sm_max_k
self.allow_self_transitions = args.sm_allow_self_transitions
self.learning_rate = args.sm_lr
self.init_params()
def init_params(self):
"""Create torch differentiable params"""
poisson_log_rates = torch.zeros(self.n_classes, dtype=torch.float)
self.poisson_log_rates = torch.nn.Parameter(poisson_log_rates, requires_grad=True)
gaussian_means = torch.zeros(self.n_classes, self.feature_dim, dtype=torch.float)
self.gaussian_means = torch.nn.Parameter(gaussian_means, requires_grad=True)
gaussian_cov = torch.ones(self.n_classes, self.feature_dim, dtype=torch.float)
self.gaussian_cov = torch.nn.Parameter(gaussian_cov, requires_grad=False)
transition_logits = torch.zeros(self.n_classes, self.n_classes, dtype=torch.float)
self.transition_logits = torch.nn.Parameter(transition_logits, requires_grad=True)
init_logits = torch.zeros(self.n_classes, dtype=torch.float)
self.init_logits = torch.nn.Parameter(init_logits, requires_grad=True)
torch.nn.init.uniform_(self.init_logits, 0, 1)
def initialize_gaussian(self, data, lengths):
b, _, d = data.size()
assert lengths.size(0) == b
feats = []
for i in range(b):
feats.append(data[i, :lengths[i]])
feats = torch.cat(feats, dim=0)
assert d == self.feature_dim
mean = feats.mean(dim=0, keepdim=True).to(self.gaussian_means.device)
self.gaussian_means.data.zero_()
self.gaussian_means.data.add_(mean.expand((self.n_classes, self.feature_dim)))
self.gaussian_cov.data = torch.ones(self.n_classes, self.feature_dim, device=self.gaussian_cov.device) * self.args.sm_cov_factor
def initialize_supervised(self, feature_list, label_list, length_list, overrides=('mean', 'cov', 'init', 'trans', 'lengths'), freeze=True):
    """Overwrite selected parameters with estimates from labelled data.

    Statistics come from ``semimarkov_sufficient_stats`` (called with
    ``covariance_type='diag'``), which returns an emission mixture object
    and a dict of count statistics.

    Args:
        feature_list, label_list, length_list: passed through unchanged to
            ``semimarkov_sufficient_stats``.
        overrides: names of the parameter groups to overwrite. Recognised
            names are 'init', 'trans', 'lengths', 'mean' and 'cov'. Only
            membership tests are performed, so any container of names
            works. The default is a tuple rather than a list so the shared
            default object cannot be mutated between calls.
        freeze: when true, each overwritten init / trans / lengths / mean
            parameter gets ``requires_grad = False``. The covariance
            parameter's flag is never changed here.
    """
    emission_gmm, stats = semimarkov_sufficient_stats(
        feature_list, label_list, length_list,
        covariance_type='diag', n_classes=self.n_classes, max_k=self.max_k,
    )

    if 'init' in overrides:
        state_smoothing = self.args.sm_supervised_state_smoothing
        # Additively smoothed start counts divided by the smoothed total.
        init_probs = (stats['span_start_counts'] + state_smoothing) / float(
            stats['instance_count'] + state_smoothing * self.n_classes)
        # NaN entries are replaced by 0 before the log is taken.
        init_probs[np.isnan(init_probs)] = 0
        self.init_logits.data.zero_()
        self.init_logits.data.add_(
            torch.from_numpy(init_probs).to(device=self.init_logits.device).log())
        if freeze:
            self.init_logits.requires_grad = False

    if 'trans' in overrides:
        smoothed_trans_counts = (
            stats['span_transition_counts'] + self.args.sm_supervised_state_smoothing)
        # Normalise so that every column sums to one (sum over axis 0).
        trans_probs = smoothed_trans_counts / smoothed_trans_counts.sum(axis=0)[None, :]
        trans_probs[np.isnan(trans_probs)] = 0
        self.transition_logits.data.zero_()
        self.transition_logits.data.add_(
            torch.from_numpy(trans_probs).to(device=self.transition_logits.device).log())
        if freeze:
            self.transition_logits.requires_grad = False

    if 'lengths' in overrides:
        length_smoothing = self.args.sm_supervised_length_smoothing
        # Smoothed ratio of span_lengths to span_counts, stored as a log.
        mean_lengths = (stats['span_lengths'] + length_smoothing) / (
            stats['span_counts'] + length_smoothing)
        self.poisson_log_rates.data.zero_()
        self.poisson_log_rates.data.add_(
            torch.from_numpy(mean_lengths).to(device=self.poisson_log_rates.device).log())
        if freeze:
            self.poisson_log_rates.requires_grad = False

    if 'mean' in overrides:
        self.gaussian_means.data.zero_()
        self.gaussian_means.data.add_(
            torch.from_numpy(emission_gmm.means_).to(
                device=self.gaussian_means.device, dtype=torch.float))
        if freeze:
            self.gaussian_means.requires_grad = False

    if 'cov' in overrides:
        self.gaussian_cov.data.zero_()
        # A fixed 1e-3 is added to the fitted covariances, then the
        # configurable sm_supervised_cov_smoothing on top of that.
        self.gaussian_cov.data.add_(
            torch.from_numpy(emission_gmm.covariances_ + 1e-3).to(
                device=self.gaussian_cov.device, dtype=torch.float))
        self.gaussian_cov.data.add_(
            torch.full(self.gaussian_cov.size(), self.args.sm_supervised_cov_smoothing).to(
                device=self.gaussian_cov.device, dtype=torch.float))
def fit_supervised(self, feature_list, label_list, length_list):
    """Fit every parameter group from labelled data without freezing any.

    This overwrites the initial-state logits, transition logits, Poisson
    log rates, Gaussian means and Gaussian covariances. It delegates to
    ``initialize_supervised`` with all groups overridden and
    ``freeze=False``, so the estimation logic lives in one place and no
    ``requires_grad`` flag is changed.

    Args:
        feature_list, label_list, length_list: forwarded unchanged to
            ``initialize_supervised``.
    """
    # The groups are named explicitly so this method does not depend on
    # the default value of ``overrides``.
    self.initialize_supervised(
        feature_list, label_list, length_list,
        overrides=('mean', 'cov', 'init', 'trans', 'lengths'),
        freeze=False,
    )
def transition_log_probs(self, valid_classes):
"""Mask out invalid classes and apply softmax to transition logits"""
transition_logits = self.transition_logits
if valid_classes is not None:
transition_logits = transition_logits[valid_classes][:, valid_classes]
n_classes = len(valid_classes)
else:
n_classes = self.n_classes
if self.allow_self_transitions:
masked = transition_logits
else:
masked = transition_logits.masked_fill(torch.eye(n_classes, device=self.transition_logits.device, dtype=bool), -1e9)
return F.log_softmax(masked, dim=0)
def emission_log_probs(self, features, valid_classes):
"""Compute likelihood of emissions for each class"""
评论0