#include<linux/kernel.h>
#include<linux/init.h>
#include<linux/module.h>
#include<linux/version.h>
#include<linux/skbuff.h>
#include <linux/net.h>
#include<linux/netfilter.h>
#include<linux/netfilter_ipv4.h>
#include <linux/netfilter_bridge.h>
#include <linux/netdevice.h>
#include <linux/init.h>
#include <linux/stat.h>
#include <linux/slab.h>
#include <net/sock.h>
#include <net/ip.h>
#include <linux/spinlock.h>
#include <linux/socket.h>
#include <linux/string.h>
#include <linux/kdev_t.h>
#include <linux/kmod.h>
#include <linux/fs.h>
#include <linux/device.h>
#include <linux/cdev.h>
#include <asm/uaccess.h>
#include <asm/unistd.h>
#include <net/netlink.h>
#include <linux/tcp.h>
#include <linux/ip.h>
#include <linux/icmp.h>
#include <linux/udp.h>
#include <linux/in.h>
#include <linux/jiffies.h>
#include <linux/time.h>
#include <linux/timex.h>
#include <linux/timer.h>
#include <linux/vmalloc.h>
#include <linux/workqueue.h>
#include <linux/if_arp.h>
#include <linux/rtc.h>
#include <linux/if_ether.h>
#include <linux/types.h>
#include <linux/proc_fs.h>
#include <linux/vmalloc.h>
#include <linux/workqueue.h>
#include <linux/spinlock.h>
#include <linux/types.h>
#include <linux/proc_fs.h>
/* NOTE(review): presumably the netlink protocol number used when creating nl_sk - confirm at the call site. */
#define NETLINK_TEST (25)
/* Protocol codes; the values equal the IANA IP protocol numbers for TCP, UDP and ICMP. */
#define TCP 6
#define UDP 17
#define ICMP 1
/* NOTE(review): looks like a wildcard value for rule fields - confirm against the rule-matching code. */
#define ANY -1
/* NOTE(review): usage is not visible here; presumably a stand-in port for ICMP traffic - confirm. */
#define ICMP_PORT 30001
/* Capacity of rules[]. */
#define MAX_RULE_NUM 50
/* Capacity of cons[] and cons2[]; CHashCheck hard-codes the same value (101) as its modulus. */
#define MAX_STATU_NUM 101
/* Capacity of natTable[]. */
#define MAX_NAT_NUM 100
/* Capacity of logs[]. */
#define MAX_LOG_NUM 100
MODULE_LICENSE("GPL");
//MODULE_AUTHER("FGY");
//struct { __u32 pid; }user_process;
/*
 * One filtering rule.
 * NOTE(review): how the fields are filled and evaluated is not visible in
 * this part of the file; the per-field notes below are assumptions to confirm.
 */
typedef struct {
char src_ip[20]; /* source address as text; presumably dotted quad with optional "/prefix", as parsed by IsMatch */
char dst_ip[20]; /* destination address as text, same presumed format */
int src_port; /* source port; presumably ANY (-1) acts as a wildcard */
int dst_port; /* destination port; presumably ANY (-1) acts as a wildcard */
char protocol; /* presumably one of TCP / UDP / ICMP / ANY */
bool action; /* verdict flag; which value means accept is not visible here */
bool log; /* presumably whether a matching packet should be logged */
}Rule;
/* Static rule table with room for MAX_RULE_NUM entries. */
static Rule rules[MAX_RULE_NUM];
static int rnum = 0; //rules num
/*
 * One tracked connection in the state table.
 */
typedef struct {
unsigned src_ip;
unsigned short src_port;
unsigned dst_ip;
unsigned short dst_port;
unsigned char protocol;
unsigned long t; /* expiry time in jiffies: CHashInsert sets jiffies + 10 * HZ; the entry counts as live while time_before(jiffies, t) */
}Connection;
/* Connection hash table; slots are chosen by CHashCheck and filled by CHashInsert. */
static Connection cons[MAX_STATU_NUM];
/* Compact copy of the live entries of cons[], rebuilt by UpdateHashList (with t set to 0). */
static Connection cons2[MAX_STATU_NUM];
static int cnum = 0; //number of entries in cons2, recomputed by UpdateHashList
/*
 * One NAT mapping.
 * NOTE(review): the NAT hooks that use this table are not visible here;
 * presumably it pairs an internal (nat_ip, nat_port) endpoint with the
 * firewall_port used on the firewall's own address - confirm.
 */
typedef struct {
//unsigned firewall_ip;
unsigned nat_ip;
unsigned short firewall_port;
unsigned short nat_port;
}NatEntry;
/* Static NAT table with room for MAX_NAT_NUM entries. */
static NatEntry natTable[MAX_NAT_NUM];
static int nnum = 0; //nat rules num
/* NOTE(review): presumably the protected network, its mask and the firewall's address - set elsewhere, confirm. */
unsigned net_ip, net_mask, firewall_ip;
/* Starts at 20000. NOTE(review): presumably the next firewall-side port to hand out - confirm. */
unsigned short firewall_port = 20000;
/*
 * One log record: the addresses, ports and protocol of a packet plus an
 * action code.
 * NOTE(review): the encoding of 'action' is not visible here - confirm
 * against the code that fills logs[].
 */
typedef struct {
unsigned src_ip;
unsigned dst_ip;
unsigned short src_port;
unsigned short dst_port;
unsigned char protocol;
unsigned char action;
}Log;
/* Static log buffer with room for MAX_LOG_NUM records. */
static Log logs[MAX_LOG_NUM];
static int lnum = 0;//logs num
/* Forward declarations: these functions are referenced by the static
 * initialisers below (nkc and the nf_hook_ops tables). */
/* Receive callback installed in nkc.input. */
static void netlink_input(struct sk_buff *__skb);
/* Hook used by both input_filter and output_filter. */
unsigned int hook_func(unsigned int,struct sk_buff *,const struct net_device *,
const struct net_device *,int(*okfn)(struct sk_buff*));
/* Hook used by input_nat_filter. */
unsigned int hook_func_nat_in(unsigned int,struct sk_buff *,const struct net_device *,
const struct net_device *,int(*okfn)(struct sk_buff*));
/* Hook used by output_nat_filter. */
unsigned int hook_func_nat_out(unsigned int,struct sk_buff *,const struct net_device *,
const struct net_device *,int(*okfn)(struct sk_buff*));
/* Character-device number; netlink_cleanup releases it with unregister_chrdev_region(devId, 1). */
static dev_t devId;
/* Device class; netlink_cleanup passes it to device_destroy() and class_destroy(). */
static struct class *cls = NULL;
/* Kernel-side netlink socket; netlink_send refuses to run while it is NULL and netlink_cleanup releases it. */
struct sock *nl_sk = NULL;
/*
 * Netlink socket configuration: no multicast groups, no flags, and
 * netlink_input as the receive callback.
 * NOTE(review): presumably passed to netlink_kernel_create() during module
 * initialisation, which is not visible here - confirm.
 */
struct netlink_kernel_cfg nkc = {
.groups = 0,
.flags = 0,
.input = netlink_input,
.cb_mutex = NULL,
.bind = NULL,
//nkc.unbind = NULL;
.compare = NULL
};
/*
 * Packet-filter hook descriptors. Both use hook_func and the highest
 * netfilter priority (NF_IP_PRI_FIRST); they differ only in the hook point.
 * The cast to nf_hookfn * is needed because hook_func is declared with an
 * explicit parameter list rather than as an nf_hookfn.
 */
static struct nf_hook_ops input_filter = {
.hook = (nf_hookfn *)hook_func,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = NF_INET_PRE_ROUTING,
.priority = NF_IP_PRI_FIRST
}; // NF_INET_PRE_ROUTING - for incoming packets
static struct nf_hook_ops output_filter = {
.hook = (nf_hookfn *)hook_func,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = NF_INET_POST_ROUTING,
.priority = NF_IP_PRI_FIRST
}; // NF_INET_POST_ROUTING - for outgoing packets
/*
 * NAT hook descriptors. hook_func_nat_in is attached at PRE_ROUTING with
 * priority NF_IP_PRI_NAT_DST and hook_func_nat_out at POST_ROUTING with
 * priority NF_IP_PRI_NAT_SRC. The filter hooks registered at the same hook
 * points use NF_IP_PRI_FIRST.
 */
static struct nf_hook_ops input_nat_filter = {
.hook = (nf_hookfn *)hook_func_nat_in,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = NF_INET_PRE_ROUTING,
.priority = NF_IP_PRI_NAT_DST
};
static struct nf_hook_ops output_nat_filter = {
.hook = (nf_hookfn *)hook_func_nat_out,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = NF_INET_POST_ROUTING,
.priority = NF_IP_PRI_NAT_SRC
};
/*-----------------------------------hash begin---------------------------------------------*/
/*
 * CHashCheck - look a connection up in cons[], or find a slot for it.
 *
 * The start slot is a hash of the 5-tuple. The hash is symmetric in
 * (src, dst), so a reply packet starts probing at the same slot as the
 * packet that created the entry. Probing advances by 7 modulo 101; since
 * 7 and 101 are coprime, every slot is visited before the probe wraps.
 *
 * A slot counts as occupied while its expiry time is still in the future.
 * Probing stops at the first expired slot, so an entry stored behind a slot
 * that has since expired is not found by this lookup.
 *
 * Return:
 *   0..100  index of a free (expired or never used) slot for CHashInsert
 *   101     the connection, in either direction, is already tracked; its
 *           timeout has been refreshed to 10 seconds from now
 *   102     every slot on the probe sequence is occupied (table full)
 */
unsigned CHashCheck(unsigned src_ip, unsigned dst_ip, unsigned char protocol,
		unsigned short src_port, unsigned short dst_port){
	/*
	 * Both ports take part in the hash. The earlier expression XORed
	 * src_port with itself, which cancels to zero and left the ports out.
	 */
	unsigned p = (src_ip ^ dst_ip ^ protocol ^ src_port ^ dst_port) % 101;
	unsigned tmp = p;

	while(time_before(jiffies, cons[p].t)){
		if((protocol == cons[p].protocol && src_ip == cons[p].src_ip && dst_ip == cons[p].dst_ip
			&& src_port == cons[p].src_port && dst_port == cons[p].dst_port) ||
		   (protocol == cons[p].protocol && dst_ip == cons[p].src_ip && src_ip == cons[p].dst_ip
			&& dst_port == cons[p].src_port && src_port == cons[p].dst_port)){
			/* Same flow, forward or reverse direction: keep it alive. */
			cons[p].t = jiffies + 10 * HZ;
			printk("hash check return 101 exist\n");
			return 101;
		}
		p = (p + 7) % 101;
		if(p == tmp){
			/* Back at the start slot: no free slot anywhere. */
			printk("hash check return 102 full\n");
			return 102;
		}
	}
	printk("hash check return p:%u\n", p);
	return p;
}
/*
 * UpdateHashList - rebuild cons2[] from the live entries of cons[].
 *
 * Resets cnum, then walks every slot of cons[]. Each slot whose expiry time
 * is still in the future is appended to cons2[] with its timestamp set to 0.
 * On return cnum holds the number of entries copied.
 */
void UpdateHashList(void){
	int slot;

	cnum = 0;
	for(slot = 0; slot < MAX_STATU_NUM; ++slot){
		/* Expired or never-used slot: nothing to copy. */
		if(!time_before(jiffies, cons[slot].t))
			continue;

		cons2[cnum].protocol = cons[slot].protocol;
		cons2[cnum].src_ip = cons[slot].src_ip;
		cons2[cnum].src_port = cons[slot].src_port;
		cons2[cnum].dst_ip = cons[slot].dst_ip;
		cons2[cnum].dst_port = cons[slot].dst_port;
		cons2[cnum].t = 0;
		cnum += 1;
	}
}
/*
 * CHashInsert - store a connection in slot p of cons[].
 *
 * @p is expected to be a slot index returned by CHashCheck. CHashCheck also
 * returns 101 ("already tracked") and 102 ("table full"), which are not
 * valid indices; any out-of-range @p is ignored so that those values can
 * never cause a write past the end of the table.
 *
 * The entry expires 10 seconds (10 * HZ jiffies) after insertion.
 */
void CHashInsert(unsigned src_ip, unsigned dst_ip, unsigned char protocol,
		unsigned short src_port, unsigned short dst_port, unsigned p){
	if(p >= sizeof(cons) / sizeof(cons[0]))
		return;

	cons[p].src_ip = src_ip;
	cons[p].dst_ip = dst_ip;
	cons[p].src_port = src_port;
	cons[p].dst_port = dst_port;
	cons[p].protocol = protocol;
	cons[p].t = jiffies + 10 * HZ;
}
/*-----------------------------------tools begin------------------------------------------*/
/*
 * IsMatch - test whether an address falls inside a textual address range.
 *
 * @ip:       address as a 32-bit number whose most significant byte is the
 *            first octet (the same layout this parser produces)
 * @ip_range: "a.b.c.d" for an exact match, or "a.b.c.d/len" for a prefix
 *            match on the top len bits; "/0" matches every address
 *
 * The string is parsed in place, so its length is not limited by any local
 * buffer. Characters are converted with the same arithmetic as before (no
 * digit validation), so every input that used to be well defined gives the
 * same result. Inputs that used to cause undefined behaviour now simply do
 * not match: a NULL pointer, a prefix length above 32, or more than four
 * dots in the address part.
 *
 * Return: true when ip and the parsed address agree on all masked bits.
 */
bool IsMatch(unsigned ip, const char *ip_range){
	const char *s;
	unsigned addr = 0, octet = 0, prefix = 0;
	unsigned mask = 0xFFFFFFFF;
	int dots = 0;

	if(!ip_range)
		return false;

	/* Address part: everything up to the first '/' or the end. */
	for(s = ip_range; *s != '\0' && *s != '/'; s++){
		if(*s == '.'){
			/* A further dot would need a negative shift count. */
			if(dots > 3)
				return false;
			addr |= octet << (8 * (3 - dots));
			octet = 0;
			dots++;
			continue;
		}
		octet = octet * 10 + (*s - '0');
	}
	addr |= octet;

	/* Optional prefix length after the '/'. */
	if(*s == '/'){
		for(s++; *s != '\0'; s++)
			prefix = prefix * 10 + (*s - '0');
		/* 32 - prefix must stay within 0..31 to be a valid shift count. */
		if(prefix > 32)
			return false;
		mask = prefix ? 0xFFFFFFFF << (32 - prefix) : 0;
	}

	return (addr & mask) == (ip & mask);
}
/*
 * ipstr_to_num - convert a dotted-quad string to a 32-bit number.
 *
 * The first octet lands in the most significant byte, so "192.168.1.1"
 * becomes 0xC0A80101. Characters are converted with plain digit arithmetic
 * and are not validated.
 *
 * A NULL pointer, or a string with more than four dots (which would need a
 * negative shift count), yields 0.
 *
 * Return: the parsed address, or 0 for the rejected inputs above.
 */
unsigned ipstr_to_num(const char *ip_str){
	const char *s;
	unsigned addr = 0, octet = 0;
	int dots = 0;

	if(!ip_str)
		return 0;

	for(s = ip_str; *s != '\0'; s++){
		if(*s == '.'){
			if(dots > 3)
				return 0;
			addr |= octet << (8 * (3 - dots));
			octet = 0;
			dots++;
			continue;
		}
		octet = octet * 10 + (*s - '0');
	}
	return addr | octet;
}
/*
 * addr_from_net - format an address as dotted-quad text.
 *
 * The four bytes of @addr are printed in memory order, first byte first.
 * At most 16 bytes, including the terminating NUL, are written to @buff,
 * so the caller must supply a buffer of at least that size.
 *
 * Return: @buff, so the call can be used directly as a printk argument.
 */
char * addr_from_net(char * buff, __be32 addr){
	const __u8 *octet = (const __u8 *)&addr;
	__u32 first = octet[0];
	__u32 second = octet[1];
	__u32 third = octet[2];
	__u32 fourth = octet[3];

	snprintf(buff, 16, "%u.%u.%u.%u", first, second, third, fourth);
	return buff;
}
/*
 * print_ip - log an address in dotted-quad form.
 *
 * Prints bits 31..24 of @ip first and bits 7..0 last, followed by a
 * newline. The arguments are unsigned long, so the conversion is %lu
 * rather than the signed %ld.
 */
void print_ip(unsigned long ip) {
	printk("%lu.%lu.%lu.%lu\n", (ip>>24)&0xff, (ip>>16)&0xff, (ip>>8)&0xff, (ip>>0)&0xff);
}
/*------------------------------------netlink begin--------------------------------------------------*/
/*
 * netlink_cleanup - release the netlink socket and the character device.
 *
 * Releases nl_sk, destroys the device identified by (cls, devId), destroys
 * the class itself, and finally gives back the single device number held
 * in devId. The device is destroyed before its class.
 */
static void netlink_cleanup(void)
{
netlink_kernel_release(nl_sk);
device_destroy(cls, devId);
class_destroy(cls);
unregister_chrdev_region(devId, 1);
}
static void netlink_send(int pid, uint8_t *message, int len)
{
struct sk_buff *skb_1;
struct nlmsghdr *nlh;
if(!message || !nl_sk) {
return;
}
skb_1 = alloc_skb(NLMSG_SPACE(len), GFP_KERNEL);
if( !skb_1 ) {