import torch
import torch.nn as nn
import numpy as np
from .KANLayer import KANLayer
#from .Symbolic_MultKANLayer import *
from .Symbolic_KANLayer import Symbolic_KANLayer
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
#from .MultKANLayer import MultKANLayer
import pandas as pd
from sympy.printing import latex
from sympy import *
import sympy
import yaml
from .spline import curve2coef
from .utils import SYMBOLIC_LIB
from .hypothesis import plot_tree
from nequip.nn import (
GraphModuleMixin,
InteractionBlock,
)
class MultKAN(GraphModuleMixin,nn.Module):
'''
KAN class
Attributes:
-----------
grid : int
the number of grid intervals
k : int
spline order
act_fun : a list of KANLayers
symbolic_fun: a list of Symbolic_KANLayer
depth : int
depth of KAN
width : list
number of neurons in each layer.
Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2).
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
grid : int
the number of grid intervals
k : int
the order of piecewise polynomial
base_fun : fun
residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_fun : a list of Symbolic_KANLayer
Symbolic_KANLayers
symbolic_enabled : bool
If False, the symbolic front is not computed (to save time). Default: True.
width_in : list
The number of input neurons for each layer
width_out : list
The number of output neurons for each layer
base_fun_name : str
The base function b(x)
grip_eps : float
The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
node_bias : a list of 1D torch.float
node_scale : a list of 1D torch.float
subnode_bias : a list of 1D torch.float
subnode_scale : a list of 1D torch.float
symbolic_enabled : bool
when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
affine_trainable : bool
indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
sp_trainable : bool
indicate whether the overall magnitude of splines is trainable
sb_trainable : bool
indicate whether the overall magnitude of base function is trainable
save_act : bool
indicate whether intermediate activations are saved in forward pass
node_scores : None or list of 1D torch.float
node attribution score
edge_scores : None or list of 2D torch.float
edge attribution score
subnode_scores : None or list of 1D torch.float
subnode attribution score
cache_data : None or 2D torch.float
cached input data
acts : None or a list of 2D torch.float
activations on nodes
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints
round : int
the number of times rewind() has been called
device : str
'''
def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
'''
initalize a KAN model
Args:
-----
width : list of int
Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
grid : int
number of grid intervals. Default: 3.
k : int
order of piecewise polynomial. Default: 3.
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
noise_scale : float
initial injected noise to spline.
base_fun : str
the residual function b(x). Default: 'silu'
symbolic_enabled : bool
compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
affine_trainable : bool
affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
grid_range : list/np.array of shape (2,))
setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
sp_trainable : bool
If true, scale_sp is trainable. Default: True.
sb_trainable : bool
If true, scale_base is trainable. Default: True.
device : str
device
seed : int
random seed
save_act : bool
indicate whether intermediate activations are saved in forward pass
sparse_init : bool
sparse initialization (True) or normal dense initialization. Default: False.
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints. Default: './model'
round : int
the number of times rewind() has been called
device : str
Returns:
--------
self
Example
-------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
'''
# super(MultKAN, self).__init__()
super().__init__()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
### initializeing the numerical front ###
self.act_fun = []
self.depth = len(width) - 1
for i in range(len(width)):
if type(width[i]) == int:
width[i] = [width[i],0]
self.width = width
# if mult_arity is just a scalar, we extend it to a list of lists
# e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
# in the second hidden layer, 1 mult op has arity 4.
if isinstance(mult_arity, int):
self.mult_homo = True # when homo is True, parallelization is possible
else:
self.mult_homo = False # when home if False,
没有合适的资源?快使用搜索试试~ 我知道了~
nequip模型代码V1
需积分: 1 0 下载量 91 浏览量
2024-10-21
18:01:45
上传
评论
收藏 244.97MB GZ 举报
温馨提示
nequip模型代码V1
资源推荐
资源详情
资源评论
收起资源包目录
nequip模型代码V1 (2000个子文件)
make.bat 795B
custom.css 879B
metrics_batch_val.csv 0B
metrics_epoch.csv 0B
metrics_initialization.csv 0B
metrics_batch_train.csv 0B
metrics_initialization.csv 0B
metrics_epoch.csv 0B
metrics_batch_val.csv 0B
metrics_batch_train.csv 0B
nequip-0.6.1-py3.9.egg 683KB
.flake8 156B
experiment1.ipynb 233KB
experiment1.ipynb 233KB
wandb-summary.json 1KB
wandb-summary.json 1KB
wandb-summary.json 1KB
wandb-summary.json 1KB
wandb-summary.json 1KB
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
wandb-metadata.json 704B
共 2000 条
- 1
- 2
- 3
- 4
- 5
- 6
- 20
资源评论
JZJQuest
- 粉丝: 292
- 资源: 10
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功