/* 
 * accounting match (ipt_account.c)
 * (C) 2003,2004 by Piotr Gasido (quaker@barbara.eu.org)
 *
 * Version: 0.1.5
 *
 * This software is distributed under the terms of GNU GPL
 */

#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/proc_fs.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <linux/interrupt.h>

#include <asm/uaccess.h>

#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>

#include <linux/netfilter_ipv4/ip_tables.h>
#include <linux/netfilter_ipv4/ipt_account.h>

/*
 * Banner passed to printk() by init() when the module is loaded.  The
 * KERN_INFO prefix is part of the string itself.
 */
static char version[] =
KERN_INFO "ipt_account 0.1.5 : Piotr Gasido <quaker@barbara.eu.org>, http://www.barbara.eu.org/~quaker/ipt_account/\n";

/* default permissions for the files created in /proc/net/ipt_account/ */
static int ip_list_perms = 0644;

/*
 * Default size limit for a single table, expressed as a prefix length.
 * To account traffic for networks bigger than /17, load the module with
 * a smaller ip_list_max_mask value.
 */
static int ip_list_max_mask = 17;
/* host-count limit enforced by checkentry(); computed from ip_list_max_mask in init() */
static int ip_list_max_hosts_count;
/* non-zero enables the KERN_INFO trace messages; exposed as a module parameter only in DEBUG builds */
static int debug = 0;

/* module information and load-time parameters */
MODULE_AUTHOR("Piotr Gasido <quaker@barbara.eu.org>");
MODULE_DESCRIPTION("Traffic accounting modules");
MODULE_LICENSE("GPL");
/* permissions passed to create_proc_entry() for every per-table file */
MODULE_PARM(ip_list_perms,"i");
MODULE_PARM_DESC(ip_list_perms,"permissions on /proc/net/ipt_account/* files");
/* prefix length that bounds the size of one table; range-checked in init() */
MODULE_PARM(ip_list_max_mask, "i");
MODULE_PARM_DESC(ip_list_max_mask, "maximum *save* size of one list (netmask)");
/* the debug switch can only be set at load time when compiled with DEBUG */
#ifdef DEBUG
MODULE_PARM(debug,"i");
MODULE_PARM_DESC(debug,"debugging level, defaults to 0");
#endif

/*
 * One set of traffic counters.  Every packet is added to the *_all pair
 * and to exactly one of the tcp/udp/icmp/other pairs (see do_account()).
 */
struct t_ipt_account_stat {
	/* byte counters */
	u_int64_t b_all;	/* all protocols */
	u_int64_t b_tcp;	/* TCP */
	u_int64_t b_udp;	/* UDP */
	u_int64_t b_icmp;	/* ICMP */
	u_int64_t b_other;	/* everything else */

	/* packet counters */
	u_int64_t p_all;	/* all protocols */
	u_int64_t p_tcp;	/* TCP */
	u_int64_t p_udp;	/* UDP */
	u_int64_t p_icmp;	/* ICMP */
	u_int64_t p_other;	/* everything else */
};
 
/*
 * Counters kept for a single IP address: one set for packets sent from
 * that address and one set for packets sent to it.
 */
struct t_ipt_account_ip_list {
	struct t_ipt_account_stat src;		/* updated by match() when the address is the packet's source */
	struct t_ipt_account_stat dest;		/* updated by match() when the address is the packet's destination */
};

/*
 * One accounting table.  Tables live in the singly linked list headed by
 * ipt_account_tables and are shared by every rule that refers to them by
 * name.
 */
struct t_ipt_account_table {
	char name[IPT_ACCOUNT_NAME_LEN];	/* table name ( = base of the file names in /proc/net/ipt_account/) */
	struct t_ipt_account_ip_list *ip_list;	/* per-address counters, indexed by (address - network) */
	struct t_ipt_account_table *next;	/* next table in the list */
	u_int32_t network;			/* network/netmask covered by table; match() compares them with ntohl()'d packet addresses */
	u_int32_t netmask;			/* see network */
	int use_count;				/* rules counter - counting number of rules using this table */
	spinlock_t ip_list_lock;		/* taken around every access to the counters in ip_list */
	struct proc_dir_entry *status_proc_64, *status_proc_32;	/* the "<name>_64" and "<name>_32" proc entries */
};

/*
 * Protects the table list below.  match() takes it for lookups,
 * checkentry() and destroy() take it while tables are added or removed.
 */
static spinlock_t ipt_account_tables_lock = SPIN_LOCK_UNLOCKED;

/* the /proc/net/ipt_account directory, created in init() */
static struct proc_dir_entry *proc_net_ipt_account = NULL;

/* head of the list of tables (linked through ->next); NULL when empty */
static struct t_ipt_account_table *ipt_account_tables = NULL;

static int ip_list_read_proc_64(char *buffer, char **start, off_t offset,
		int length, int *eof, void *data) {
	
	int len = 0, last_len = 0;
	off_t pos = 0, begin = 0;

	u_int32_t address, index;	

	struct t_ipt_account_table *table = (struct t_ipt_account_table*)data;

	spin_lock(&table->ip_list_lock);
	for (address = table->network; (u_int32_t)(address & table->netmask) == (u_int32_t)(table->network); address++) {
		last_len = len;		
		index = address - table->network;
			len += sprintf(buffer + len,
				"ip = %u.%u.%u.%u bytes_src = %llu %llu %llu %llu %llu packets_src = %llu %llu %llu %llu %llu bytes_dest = %llu %llu %llu %llu %llu packets_dest = %llu %llu %llu %llu %llu\n",
				HIPQUAD(address),
				table->ip_list[index].src.b_all,
				table->ip_list[index].src.b_tcp,
				table->ip_list[index].src.b_udp,
				table->ip_list[index].src.b_icmp,
				table->ip_list[index].src.b_other,
				
				table->ip_list[index].src.p_all,
				table->ip_list[index].src.p_tcp,
				table->ip_list[index].src.p_udp,
				table->ip_list[index].src.p_icmp,
				table->ip_list[index].src.p_other,
				
				table->ip_list[index].dest.b_all,
				table->ip_list[index].dest.b_tcp,
				table->ip_list[index].dest.b_udp,
				table->ip_list[index].dest.b_icmp,
				table->ip_list[index].dest.b_other,				
				
				table->ip_list[index].dest.p_all,
				table->ip_list[index].dest.p_tcp,
				table->ip_list[index].dest.p_udp,
				table->ip_list[index].dest.p_icmp,
				table->ip_list[index].dest.p_other
				);
		pos = begin + len;
		if (pos < offset) {
			len = 0;
			begin = pos;
		}
		if (pos > offset + length) {
			len = last_len;
			break;
		}
	}
	spin_unlock(&table->ip_list_lock);
	*start = buffer + (offset - begin);
	len -= (offset - begin);
	if (len > length)
		len = length;
	return len;
}

static int ip_list_read_proc_32(char *buffer, char **start, off_t offset,
		int length, int *eof, void *data) {
	
	int len = 0, last_len = 0;
	off_t pos = 0, begin = 0;

	u_int32_t address, index;

	struct t_ipt_account_table *table = (struct t_ipt_account_table*)data;
	
	spin_lock(&table->ip_list_lock);
	for (address = table->network; (u_int32_t)(address & table->netmask) == (u_int32_t)(table->network); address++) {
		last_len = len;	
		index = address - table->network;
		len += sprintf(buffer + len,
				"ip = %u.%u.%u.%u bytes_src = %u %u %u %u %u packets_src = %u %u %u %u %u bytes_dest = %u %u %u %u %u packets_dest = %u %u %u %u %u\n",
				HIPQUAD(address),
				(u_int32_t)table->ip_list[index].src.b_all,
				(u_int32_t)table->ip_list[index].src.b_tcp,
				(u_int32_t)table->ip_list[index].src.b_udp,
				(u_int32_t)table->ip_list[index].src.b_icmp,
				(u_int32_t)table->ip_list[index].src.b_other,
				
				(u_int32_t)table->ip_list[index].src.p_all,
				(u_int32_t)table->ip_list[index].src.p_tcp,
				(u_int32_t)table->ip_list[index].src.p_udp,
				(u_int32_t)table->ip_list[index].src.p_icmp,
				(u_int32_t)table->ip_list[index].src.p_other,
				
				(u_int32_t)table->ip_list[index].dest.b_all,
				(u_int32_t)table->ip_list[index].dest.b_tcp,
				(u_int32_t)table->ip_list[index].dest.b_udp,
				(u_int32_t)table->ip_list[index].dest.b_icmp,
				(u_int32_t)table->ip_list[index].dest.b_other,
				
				(u_int32_t)table->ip_list[index].dest.p_all,
				(u_int32_t)table->ip_list[index].dest.p_tcp,
				(u_int32_t)table->ip_list[index].dest.p_udp,
				(u_int32_t)table->ip_list[index].dest.p_icmp,
				(u_int32_t)table->ip_list[index].dest.p_other
		
			);
		pos = begin + len;
		if (pos < offset) {
			len = 0;
			begin = pos;
		}
		if (pos > offset + length) {
			len = last_len;
			break;
		}		
	}
	spin_unlock(&table->ip_list_lock);				
	*start = buffer + (offset - begin);
	len -= (offset - begin);
	if (len > length)
		len = length;
	return len;
}

/*
 * write_proc handler shared by the "<table>_64" and "<table>_32" files.
 * The only command understood is "reset" (with or without a trailing
 * newline), which zeroes every counter of the table; anything else is
 * accepted and ignored.
 *
 * file   - not used
 * buffer - user-space pointer to the written data
 * length - number of bytes written by the user
 * data   - the struct t_ipt_account_table registered in checkentry()
 *
 * Returns the number of bytes consumed, or -EFAULT when the user buffer
 * cannot be read.
 */
static int ip_list_write_proc(struct file *file, const char *buffer, 
		unsigned long length, void *data) {

	struct t_ipt_account_table *table = (struct t_ipt_account_table*)data;
	char kernel_buffer[1024];
	unsigned long len;
	u_int32_t hosts_count = INADDR_BROADCAST - table->netmask + 1;

	if (length == 0)
		return 0;

	/* never copy more than the buffer holds; keep one byte for the terminator */
	len = (length < sizeof(kernel_buffer)) ? length : sizeof(kernel_buffer) - 1;

	if (copy_from_user(kernel_buffer, buffer, len))
		return -EFAULT;
	kernel_buffer[len] = 0;

	/* "echo reset" appends a newline - strip it before comparing */
	if (kernel_buffer[len - 1] == '\n')
		kernel_buffer[len - 1] = 0;

	/* echo "reset" > /proc/net/ipt_account/<table>_64 (or _32) clears the table */
	if (!strcmp(kernel_buffer, "reset")) {
		/*
		 * match() takes this lock from the packet path, so bottom halves
		 * have to stay disabled while it is held in process context.
		 */
		spin_lock_bh(&table->ip_list_lock);
		memset(table->ip_list, 0, sizeof(struct t_ipt_account_ip_list) * hosts_count);
		spin_unlock_bh(&table->ip_list_lock);
	}

	return (int)len;
}

/*
 * Account one packet of 'pktlen' bytes in *stat: the "all" counters are
 * always bumped, plus the byte/packet pair selected by 'proto'
 * (TCP, UDP, ICMP, or "other" for any other protocol number).
 */
static void do_account(struct t_ipt_account_stat *stat, u_int8_t proto, u_int16_t pktlen) {
	u_int64_t *proto_bytes;
	u_int64_t *proto_packets;

	/* choose the per-protocol counter pair */
	if (proto == IPPROTO_TCP) {
		proto_bytes = &stat->b_tcp;
		proto_packets = &stat->p_tcp;
	} else if (proto == IPPROTO_UDP) {
		proto_bytes = &stat->b_udp;
		proto_packets = &stat->p_udp;
	} else if (proto == IPPROTO_ICMP) {
		proto_bytes = &stat->b_icmp;
		proto_packets = &stat->p_icmp;
	} else {
		proto_bytes = &stat->b_other;
		proto_packets = &stat->p_other;
	}

	/* totals first, then the protocol-specific pair */
	stat->b_all += pktlen;
	stat->p_all += 1;

	*proto_bytes += pktlen;
	*proto_packets += 1;
}

static int match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const void *matchinfo,
	  int offset,
	  const void *hdr,
	  u_int16_t datalen,
	  int *hotdrop)
{
	
	const struct t_ipt_account_info *info = (struct t_ipt_account_info*)matchinfo;
	struct t_ipt_account_table *table;
	int ret;

	u_int32_t address;
	u_int16_t pktlen;
	u_int8_t proto;
	
	if (debug) {
		printk(KERN_INFO "ipt_account: match() entering.\n");
		printk(KERN_INFO "ipt_account: match() match name = %s.\n", info->name);
	}
	
	spin_lock(&ipt_account_tables_lock);
	/* find the right table */
	table = ipt_account_tables;
	while (table && strncmp(table->name, info->name, IPT_ACCOUNT_NAME_LEN) && (table = table->next));
	spin_unlock(&ipt_account_tables_lock);

	if (table == NULL) {
		/* ups, no table with that name */
		if (debug)
			printk(KERN_INFO "ipt_account: match() table %s not found. Leaving.\n", info->name);
		return 0;
	}

	if (debug)
		printk(KERN_INFO "ipt_account: match() table found %s\n", table->name);

	/* default: no match */
	ret = 0;

	/* get packet protocol/length */
	pktlen = skb->len;
	proto = skb->nh.iph->protocol;

	if (debug)
		printk(KERN_INFO "ipt_account: match() got packet src = %u.%u.%u.%u, dst = %u.%u.%u.%u, proto = %u.\n",
				 NIPQUAD(skb->nh.iph->saddr),
				 NIPQUAD(skb->nh.iph->daddr),
				 proto
				 );

	/* check whether traffic from source ip address ... */
	address = ntohl(skb->nh.iph->saddr);
	
	/* ... is being accounted by this table */	
	if (address && ((u_int32_t)(address & table->netmask) == (u_int32_t)table->network)) {		
		if (debug)
			printk(KERN_INFO "ipt_account: match() accounting packet src = %u.%u.%u.%u, proto = %u.\n",
					 HIPQUAD(address),
					 proto
					 );
		/* yes, account this packet */
		spin_lock(&table->ip_list_lock);
		/* update counters this host */
		do_account(&table->ip_list[(u_int32_t)(address - table->network)].src, proto, pktlen);
		/* update counters for all hosts in this table (network address) */
		if (table->netmask != INADDR_BROADCAST)
			do_account(&table->ip_list[0].src, proto, pktlen);
		spin_unlock(&table->ip_list_lock);
		/* yes, it's a match */
		ret = 1;
	}

	/* do the same thing with destination ip address */
	address = ntohl(skb->nh.iph->daddr);
    if (address && ((u_int32_t)(address & table->netmask) == (u_int32_t)table->network)) {
		if (debug)
			printk(KERN_INFO "ipt_account: match() accounting packet dst = %u.%u.%u.%u, proto = %u.\n",
					HIPQUAD(address),
					proto
					);
		spin_lock(&table->ip_list_lock);
		do_account(&table->ip_list[(u_int32_t)(address - table->network)].dest, proto, pktlen);
		if (table->netmask != INADDR_BROADCAST)
			do_account(&table->ip_list[0].dest, proto, pktlen);
		spin_unlock(&table->ip_list_lock);
		ret = 1;
	
	}

	if (debug)
		printk(KERN_INFO "ipt_account: match() leaving.\n");	
	
	return ret;
	
}

/*
 * Called when a rule using this match is inserted.
 *
 * If a table with the rule's name already exists, its network/netmask must
 * equal the rule's; the table's use count is then incremented.  Otherwise a
 * new table is allocated together with its counters and its two /proc
 * files ("<name>_64" and "<name>_32") and linked into the table list.
 *
 * tablename, ip and hook_mask are not used.
 *
 * Returns 1 when the rule is accepted and 0 when it is rejected.  Failure
 * paths must return 0 and not a negative errno: the result is interpreted
 * as true/false, so any non-zero value would read as "accepted" and leave
 * a rule installed that has no table behind it.
 */
static int checkentry(const char *tablename,
	       const struct ipt_ip *ip,
	       void *matchinfo,
	       unsigned int matchinfosize,
	       unsigned int hook_mask)
{
	const struct t_ipt_account_info *info = matchinfo;
	struct t_ipt_account_table *table;

	char proc_entry_name[IPT_ACCOUNT_NAME_LEN + 3];

	u_int32_t hosts_count;

	if (debug)
		printk(KERN_INFO "ipt_account: checkentry() entering.\n");

	if (matchinfosize != IPT_ALIGN(sizeof(struct t_ipt_account_info)))
		return 0;

	if (!info->name || !info->name[0])
		return 0;

	/*
	 * The name is used with %s and copied into proc_entry_name below, so
	 * it has to be NUL-terminated within IPT_ACCOUNT_NAME_LEN bytes.
	 */
	if (strnlen(info->name, IPT_ACCOUNT_NAME_LEN) == IPT_ACCOUNT_NAME_LEN)
		return 0;

	/*
	 * NOTE(review): vmalloc() and create_proc_entry() below may sleep, yet
	 * they are called with this spinlock held.  Fixing that means
	 * allocating outside the lock and re-checking the list afterwards; it
	 * is left as is here.
	 */
	spin_lock(&ipt_account_tables_lock);

	/* find whether table with this name already exists */
	for (table = ipt_account_tables; table != NULL; table = table->next) {
		if (strncmp(info->name, table->name, IPT_ACCOUNT_NAME_LEN) == 0)
			break;
	}

	if (table) {
		/* yes, table exists */
		if (info->network != table->network || info->netmask != table->netmask) {
			/* 
			 * tried to do accounting in existing table, but network/netmask in iptable rule
			 * doesn't match network/netmask in table structure - deny adding the rule 
			 */
			printk(KERN_ERR "ipt_account: checkentry() table %s found. But table netmask/network %u.%u.%u.%u/%u.%u.%u.%u differs from rule netmask/network %u.%u.%u.%u/%u.%u.%u.%u. Leaving without creating entry.\n", 
					table->name,
					HIPQUAD(table->network),
					HIPQUAD(table->netmask),
					HIPQUAD(info->network),
					HIPQUAD(info->netmask)
					);
			goto out_unlock;
		}
		if (debug)
			printk(KERN_INFO "ipt_account: checkentry() table %s found. Incrementing use count (use_count = %i). Leaving.\n", table->name, table->use_count);
		/* increase table use count */
		table->use_count++;
		spin_unlock(&ipt_account_tables_lock);
		/* everything went fine */
		return 1;
	}

	if (debug)
		printk(KERN_INFO "ipt_account: checkentry() table %s not found. Creating.\n", info->name);

	/* table doesn't exist - create one */
	table = vmalloc(sizeof(struct t_ipt_account_table));
	if (table == NULL) {
		if (debug)
			printk(KERN_INFO "ipt_account: checkentry() unable to allocate memory (t_account_table) for table %s. Leaving.\n", info->name);
		goto out_unlock;
	}

	/* set table parameters; the name is known to be terminated (checked above) */
	strncpy(table->name, info->name, IPT_ACCOUNT_NAME_LEN);
	table->use_count = 1;
	table->network = info->network;
	table->netmask = info->netmask;
	table->ip_list_lock = SPIN_LOCK_UNLOCKED;

	/* number of addresses covered by the netmask */
	hosts_count = INADDR_BROADCAST - table->netmask + 1;

	if (debug)
		printk(KERN_INFO "ipt_account: checkentry() allocating memory for %u hosts (%u netmask).\n", hosts_count, info->netmask);

	/* check whether table is not too big */
	if (hosts_count > ip_list_max_hosts_count) {
		printk(KERN_ERR "ipt_account: checkentry() unable allocate memory for %u hosts (%u netmask). Increase value of ip_list_max_mask parameter.\n", hosts_count, info->netmask);
		goto out_free_table;
	}

	table->ip_list = vmalloc(sizeof(struct t_ipt_account_ip_list) * hosts_count);
	if (table->ip_list == NULL) {
		if (debug)
			printk(KERN_INFO "ipt_account: checkentry() unable to allocate memory (t_account_ip_list) for table %s. Leaving.\n", table->name);
		goto out_free_table;
	}

	memset(table->ip_list, 0, sizeof(struct t_ipt_account_ip_list) * hosts_count);

	/* 
	 * create entries in /proc/net/ipt_account: one with full 64-bit counters and
	 * second with 32-bit ones. The second can be used in programs supporting only 32-bit numbers
	 * (mrtg, rrdtool).
	 */

	strncpy(proc_entry_name, table->name, IPT_ACCOUNT_NAME_LEN);
	strncat(proc_entry_name, "_64", 4);

	table->status_proc_64 = create_proc_entry(proc_entry_name, ip_list_perms, proc_net_ipt_account);
	if (table->status_proc_64 == NULL) {
		if (debug)
			printk(KERN_INFO "ipt_account: checkentry() unable to allocate memory (status_proc_64) for table %s. Leaving.\n", table->name);
		goto out_free_list;
	}

	table->status_proc_64->owner = THIS_MODULE;
	table->status_proc_64->read_proc = ip_list_read_proc_64;
	table->status_proc_64->write_proc = ip_list_write_proc;
	table->status_proc_64->data = table;

	strncpy(proc_entry_name, table->name, IPT_ACCOUNT_NAME_LEN);
	strncat(proc_entry_name, "_32", 4);

	table->status_proc_32 = create_proc_entry(proc_entry_name, ip_list_perms, proc_net_ipt_account);
	if (table->status_proc_32 == NULL) {
		if (debug)
			printk(KERN_INFO "ipt_account: checkentry() unable to allocate memory (status_proc_32) for table %s. Leaving.\n", table->name);
		goto out_remove_proc_64;
	}

	table->status_proc_32->owner = THIS_MODULE;
	table->status_proc_32->read_proc = ip_list_read_proc_32;
	table->status_proc_32->write_proc = ip_list_write_proc;
	table->status_proc_32->data = table;

	/* finally, insert table into list */
	table->next = ipt_account_tables;
	ipt_account_tables = table;

	if (debug)
		printk(KERN_INFO "ipt_account: checkentry() successfully created table %s (use_count = %i).\n", table->name, table->use_count);

	spin_unlock(&ipt_account_tables_lock);

	if (debug)
		printk(KERN_INFO "ipt_account: checkentry() leaving.\n");
	return 1;

	/* error unwinding: each label undoes one more step, in reverse order */
out_remove_proc_64:
	/* the "_64" file still points at the table that is about to be freed */
	strncpy(proc_entry_name, table->name, IPT_ACCOUNT_NAME_LEN);
	strncat(proc_entry_name, "_64", 4);
	remove_proc_entry(proc_entry_name, proc_net_ipt_account);
out_free_list:
	vfree(table->ip_list);
out_free_table:
	vfree(table);
out_unlock:
	spin_unlock(&ipt_account_tables_lock);
	return 0;
}

/*
 * Called when a rule using this match is removed.
 *
 * Drops one reference on the table named by the rule.  When the last
 * reference goes away the table is unlinked from the list, its two /proc
 * files are removed and its memory is released.
 */
static void destroy(void *matchinfo, 
	     unsigned int matchinfosize)
{
	const struct t_ipt_account_info *info = matchinfo;
	struct t_ipt_account_table *victim;
	struct t_ipt_account_table *prev = NULL;
	char proc_entry_name[IPT_ACCOUNT_NAME_LEN + 3];

	if (debug)
		printk(KERN_INFO "ipt_account: destroy() entered.\n");

	if (matchinfosize != IPT_ALIGN(sizeof(struct t_ipt_account_info)))
		return;

	spin_lock(&ipt_account_tables_lock);

	if (ipt_account_tables == NULL) {
		/* empty list although a rule is being removed - should not happen */
		if (debug)
			printk(KERN_INFO "ipt_account: destroy() unable to found any tables (asked for %s). Leaving.\n", info->name);
		goto unlock;
	}

	/* look the table up, remembering its predecessor for the unlink below */
	for (victim = ipt_account_tables; victim != NULL; prev = victim, victim = victim->next) {
		if (strncmp(victim->name, info->name, IPT_ACCOUNT_NAME_LEN) == 0)
			break;
	}

	if (victim == NULL) {
		printk(KERN_ERR "ipt_account: destroy() unable to found table %s. Leaving.\n", info->name);
		goto unlock;
	}

	/* drop this rule's reference */
	if (--victim->use_count != 0) {
		/* other rules still use the table, keep it */
		if (debug)
			printk(KERN_INFO "ipt_account: destroy() table %s is still used (use_count = %i). Leaving.\n", victim->name, victim->use_count);
		goto unlock;
	}

	/* last user is gone - take the table down */
	if (debug)
		printk(KERN_INFO "ipt_account: destroy() removing table %s (use_count = %i).\n", victim->name, victim->use_count);

	/* unlink from the list */
	if (prev == NULL)
		ipt_account_tables = victim->next;
	else
		prev->next = victim->next;

	/* acquire and release the counter lock once before the table is torn down */
	spin_lock(&victim->ip_list_lock);
	spin_unlock(&victim->ip_list_lock);

	/* remove both procfs files, then the memory they referred to */
	strncpy(proc_entry_name, victim->name, IPT_ACCOUNT_NAME_LEN);
	strncat(proc_entry_name, "_64", 4);
	remove_proc_entry(proc_entry_name, proc_net_ipt_account);

	strncpy(proc_entry_name, victim->name, IPT_ACCOUNT_NAME_LEN);
	strncat(proc_entry_name, "_32", 4);
	remove_proc_entry(proc_entry_name, proc_net_ipt_account);

	vfree(victim->ip_list);
	vfree(victim);

	spin_unlock(&ipt_account_tables_lock);

	if (debug)
		printk(KERN_INFO "account: destroy() leaving.\n");

	return;

unlock:
	spin_unlock(&ipt_account_tables_lock);
}

/*
 * Match descriptor passed to ipt_register_match() in init() and to
 * ipt_unregister_match() in fini().  The initializer is positional, so the
 * entries must stay in the field order of struct ipt_match (declared
 * outside this file).
 */
static struct ipt_match account_match = {
	{ NULL, NULL },		/* presumably the list linkage filled in on registration - TODO confirm */
	"account",		/* presumably the match name rules refer to - TODO confirm */
	&match,			/* per-packet callback */
	&checkentry,		/* rule insertion callback */
	&destroy,		/* rule removal callback */
	THIS_MODULE
};

/*
 * Module load: validates ip_list_max_mask, derives the per-table host
 * limit from it, creates /proc/net/ipt_account and registers the match.
 *
 * Returns 0 on success or a negative errno.  Every failure path has to
 * return non-zero and undo what was set up so far; otherwise the module
 * would stay loaded without a registered match and fini() would later
 * unregister a match that was never registered.
 */
static int __init init(void) 
{
	int ret;

	printk(version);
	if (debug)
		printk(KERN_INFO "account: __init(): ip_list_perms = %i, ip_list_max_mask = %i\n", ip_list_perms, ip_list_max_mask);
	/* check params */
	if (ip_list_max_mask > 32 || ip_list_max_mask < 0) {
		printk(KERN_ERR "account: Wrong netmask given by ip_list_max_mask parameter (%u). Valid is 32 to 0.\n", ip_list_max_mask);
		return -EINVAL;
	}

	/*
	 * NOTE(review): for ip_list_max_mask 0 and 1 this shift does not fit
	 * in an int (shifting by 32 is undefined, 1 << 31 overflows).  The
	 * expression is left unchanged here; confirm whether such masks need
	 * to be supported at all.
	 */
	ip_list_max_hosts_count = (1 << (32 - ip_list_max_mask)) + 1;

	/* create /proc/net/ipt_account directory */
	proc_net_ipt_account = proc_mkdir("ipt_account", proc_net);
	if (!proc_net_ipt_account)
		return -ENOMEM;

	ret = ipt_register_match(&account_match);
	if (ret != 0) {
		/* do not leave an orphaned directory behind when registration fails */
		remove_proc_entry("ipt_account", proc_net);
		proc_net_ipt_account = NULL;
	}

	return ret;
}

/*
 * Module unload: unregisters the match first, then removes the
 * /proc/net/ipt_account directory created by init().
 */
static void __exit fini(void) 
{
	ipt_unregister_match(&account_match);
	/* remove /proc/net/ipt_account/ directory */
	remove_proc_entry("ipt_account", proc_net);
}

/* register init() and fini() as the module's load and unload entry points */
module_init(init);
module_exit(fini);

