/* Copyright 2007 by Kim Minh Kaplan
 *
 * greyfix.c version 0.3.2
 *
 * Postfix policy daemon designed to prevent spam using the
 * greylisting method.
 *
 * Greylisting: http://projects.puremagic.com/greylisting/
 * Postfix: http://www.postfix.org/
 * Kim Minh Kaplan: http://www.kim-minh.com/
 *
 */
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <unistd.h>
#include <syslog.h>
#include <sys/stat.h>

#include <db.h>

/**
 * This determines how many seconds we will block inbound mail that is
 * from a previously unknown (ip, from, to) triplet.  If it is set to
 * zero, incoming mail association will be learned, but no deliveries
 * will be tempfailed.  Use a setting of zero with caution, as it will
 * learn spammers as well as legitimate senders.
 *
 * Default for greylist_delay; overridden by -g / --greylist-delay.
 */
#define DELAY_MAIL_SECS (58 * 60)	/* 58 minutes */
/**
 * This determines how many seconds of life are given to a record that
 * is created from a new mail [ip,from,to] triplet.  Note that the
 * window created by this setting for passing mails is reduced by the
 * amount set for DELAY_MAIL_SECS.  See also UPDATE_RECORD_LIFE_SECS.
 *
 * Default for bloc_max_idle; overridden by -b / --bloc-max-idle.
 */
#define AUTO_RECORD_LIFE_SECS (5 * 3600) /* 5 hours */
/**
 * How much life (in secs) to give to a record we are updating from an
 * allowed (passed) email.
 *
 * The default is 36 days, which should be enough to handle messages
 * that may only be sent once a month, or on things like the first
 * Monday of the month (which sometimes means 5 weeks).  Plus, we add
 * a day for a delivery buffer.
 *
 * Default for pass_max_idle; overridden by -p / --pass-max-idle.
 */
#define UPDATE_RECORD_LIFE_SECS (36 * 24 * 3600)

/* Default Berkeley DB environment directory, overridden by -h / --home.
 * DATA_STATE_DIR and PACKAGE are presumably provided by config.h. */
#define DEF_DB_HOME DATA_STATE_DIR"/"PACKAGE

/* Database file names, opened inside the environment directory. */
#define DB_STAT_NAME "stats.db"
#define DB_FILE_NAME "triplets.db"
/* NOTE(review): SEP is not referenced in the code shown here; keys are
 * built with literal 0 separators in build_triplet_key(). */
#define SEP '\000'

/* Counter for triplets that are expired and never passed anything
 * through.  Used as a key in the statistics database (stats.db). */
#define STAT_BLOCKED "blocked"

/* Non-zero when string S begins with P.  P must be an array or a string
 * literal, not a pointer: its length is taken with sizeof. */
#define prefixp(s,p) (!strncmp((s),(p),sizeof(p)-1))

/* Value stored in the triplets database for each (ip, from, to) key.
 * The struct is stored as raw bytes, so the on-disk layout depends on
 * the platform's type sizes and padding. */
struct triplet_data {
    time_t create_time;		/* start of the current greylisting period */
    time_t access_time;		/* time of the latest lookup */
    unsigned long block_count;	/* requests deferred */
    unsigned long pass_count;	/* requests allowed through */
};

/* protocol_name / protocol_state values matched with prefixp(). */
static const char str_smtp[] = "SMTP";
static const char str_esmtp[] = "ESMTP";
static const char str_rcpt[] = "RCPT";
/* NOTE(review): str_action is not referenced in the code shown here. */
static const char str_action[] = "action=";

static const char *progname;	/* basename of argv[0]; syslog identity */
static const char *db_home = DEF_DB_HOME; /* DB environment dir (-h / --home) */
static int opt_verbose = 0;	/* number of -v / --verbose flags */

/* Berkeley DB handles: opened by initialize(), closed by cleanup(). */
static DB_ENV *dbenv = 0;
static DB *db = 0;		/* triplets database (triplets.db) */
static DB *statsdb = 0;		/* statistics database (stats.db) */

/* Buffer holding policy request text read from stdin.  After the
 * current request it may already contain the start of the next one.
 * policy_request_size is the allocated size, policy_request_fill the
 * number of valid bytes. */
static char *policy_request = 0;
static size_t policy_request_size = 0;
static size_t policy_request_fill = 0;

/* Key and value buffers for the triplets database, reused for every
 * request.  initialize() points dbdata at triplet_data. */
static DBT dbkey = { 0 };
static DBT dbdata = { 0 };
static struct triplet_data triplet_data;

static int debug_me = 0;	/* -d / --debug */
/* Tunables in seconds, initialised from the defaults above and
 * overridable on the command line. */
static unsigned long greylist_delay = DELAY_MAIL_SECS;
static unsigned long bloc_max_idle = AUTO_RECORD_LIFE_SECS;
static unsigned long pass_max_idle = UPDATE_RECORD_LIFE_SECS;
/* As we store IP addresses in Postfix's format, to obtain the network
 address we first strip `ipv4_network_strip_bytes' numbers (between 0
 and 4) then we apply `ipv4_network_mask' on the last byte.  Both are
 computed from --network-prefix in main(). */
static unsigned int ipv4_network_strip_bytes;
static unsigned int ipv4_network_mask;

/**********************************************************************
 * Berkeley DB routines
 */
/* Log MSG at LOG_ERR followed by Berkeley DB's description of ERROR. */
static void
log_db_error(const char *msg, int error)
{
    const char *reason = db_strerror(error);
    syslog(LOG_ERR, "%s: %s", msg, reason);
}

/* Berkeley DB error callback: forward library diagnostics to syslog,
 * prefixed with ERRPFX when one is set, "Berkeley DB" otherwise. */
static void
db_errcall_fcn(const DB_ENV *dbenv, const char *errpfx, const char *msg)
{
    const char *prefix = "Berkeley DB";
    if (errpfx)
	prefix = errpfx;
    syslog(LOG_ERR, "%s: %s", prefix, msg);
}

/* Create the Berkeley DB environment handle and open it in db_home
 * (creating it if needed) with Concurrent Data Store locking and a
 * memory pool.  Errors are logged; returns 0 or the BDB error code. */
static int
prepare_env()
{
    int rc = db_env_create(&dbenv, 0);
    if (rc) {
	log_db_error("db_env_create", rc);
	return rc;
    }
    dbenv->set_errcall(dbenv, db_errcall_fcn);
    rc = dbenv->open(dbenv, db_home,
		     DB_INIT_CDB | DB_INIT_MPOOL | DB_CREATE, 0);
    if (rc)
	log_db_error("dbenv->open", rc);
    return rc;
}

/* Open (creating if needed) the triplets B-tree database and the
 * statistics B-tree database inside dbenv.  Errors are logged; returns
 * 0 or the first BDB error code.  Handles created before a failure are
 * left in db / statsdb so that cleanup() can close them. */
static int
prepare_db()
{
    int rc;

    rc = db_create(&db, dbenv, 0);
    if (rc) {
	log_db_error("db_create", rc);
	return rc;
    }
    rc = db->open(db, NULL, DB_FILE_NAME, NULL, DB_BTREE,
		  DB_CREATE, 0644);
    if (rc) {
	log_db_error("db->open", rc);
	return rc;
    }

    rc = db_create(&statsdb, dbenv, 0);
    if (rc) {
	syslog(LOG_ERR, "BDB-%d: db_create statsdb: %s",
	       rc, db_strerror(rc));
	return rc;
    }
    /* Call open through the handle being opened, not through db. */
    rc = statsdb->open(statsdb, NULL, DB_STAT_NAME, NULL, DB_BTREE,
		       DB_CREATE, 0644);
    if (rc)
	syslog(LOG_ERR, "BDB-%d: db->open %s: %s",
	       rc, DB_STAT_NAME, db_strerror(rc));
    return rc;
}

/* Check that the Berkeley DB library in use matches the headers this
 * daemon was built with, point dbdata at triplet_data, then open the
 * environment and the databases.
 *
 * Aborts on a major/minor version mismatch.  Returns 0 on success or
 * the first error code from prepare_env() / prepare_db(). */
static int
initialize()
{
    int major, minor, patch;
    char *version = db_version(&major, &minor, &patch);
    int rc;

    /* Refuse to run against a library from another major/minor release. */
    if (major != DB_VERSION_MAJOR || minor != DB_VERSION_MINOR) {
	syslog(LOG_ERR,
	       "This daemon was compiled with " DB_VERSION_STRING " (%d.%d.%d) definitions "
	       "but it is using %s (%d.%d.%d).  This will not work!  "
	       "Check that the version of the developpement files for Berkeley DB "
	       "match the version that used.",
	       DB_VERSION_MAJOR, DB_VERSION_MINOR, DB_VERSION_PATCH,
	       version, major, minor, patch);
	abort();
    }
    if (patch == DB_VERSION_PATCH || !(opt_verbose || debug_me)) {
	/* Same patch level, or nobody asked to hear about the difference. */
	if (debug_me)
	    syslog(LOG_DEBUG,
		   "This daemon was compiled with " DB_VERSION_STRING " (%d.%d.%d) definitions.",
		   DB_VERSION_MAJOR, DB_VERSION_MINOR, DB_VERSION_PATCH);
    }
    else
	syslog(LOG_INFO,
	       "Compiled with " DB_VERSION_STRING " (%d.%d.%d) definitions.  "
	       "Running with %s (%d.%d.%d).",
	       DB_VERSION_MAJOR, DB_VERSION_MINOR, DB_VERSION_PATCH,
	       version, major, minor, patch);

    /* triplet_data is the fixed buffer that lookups read into and
       updates write from. */
    dbdata.data = &triplet_data;
    dbdata.ulen = sizeof triplet_data;
    dbdata.size = sizeof triplet_data;
    dbdata.flags = DB_DBT_USERMEM;

    rc = prepare_env();
    if (rc == 0)
	rc = prepare_db();
    return rc;
}

static void
cleanup()
{
    int rc;
    if (dbkey.data)
	free(dbkey.data);
    if (policy_request)
	free(policy_request);
    if (statsdb) {
	rc = statsdb->close(statsdb, 0);
	statsdb = 0;
	if (rc)
	    syslog(LOG_ERR, "BDB-%d: statsdb close: %s", rc, db_strerror(rc));
    }
    if (db) {
	rc = db->close(db, 0);
	db = 0;
	if (rc)
	    log_db_error("DB close", rc);
    }
    if (dbenv) {
	rc = dbenv->close(dbenv, 0);
	dbenv = 0;
	if (rc)
	    log_db_error("DB_ENV close", rc);
	/* Clean up the environment so that upgrading is easier. */
	rc = db_env_create(&dbenv, 0);
	if (rc)
	    log_db_error("db_env_create failed during cleanup", rc);
	else {
	    rc = dbenv->remove(dbenv, db_home, 0);
	    if (rc)
		syslog(LOG_WARNING, "BDB-%d: db_env->remove failed: %s",
		       rc, db_strerror(rc));
	}
    }
    if (debug_me)
	syslog(LOG_DEBUG, "Cleaned");
}

/* Release resources, log MSG at LOG_ERR and abort the process.
 *
 * errno is captured before cleanup() can change it.  Its description
 * is appended only when it is non-zero, so that command-line usage
 * errors (where errno is unset) are not reported with a meaningless
 * reason for error code 0.  Does not return. */
static void
fatal(const char *msg)
{
    int err = errno;
    cleanup();
    if (err)
	syslog(LOG_ERR, "fatal: %s: %s", msg, db_strerror(err));
    else
	syslog(LOG_ERR, "fatal: %s", msg);
    abort();
}

/* realloc() that never hands back NULL: on allocation failure it calls
 * fatal(), which aborts the process. */
static void *
xrealloc(void *ptr, size_t size)
{
    void *newptr = realloc(ptr, size);
    if (!newptr)
	fatal("Out of memory in xrealloc");
    /* Return on every path.  The compiler cannot see that fatal() does
       not return, and the old shape fell off the end of a non-void
       function. */
    return newptr;
}

/* Build the triplets-database key for (IP, FROM, TO) into dbkey.
 *
 * Each argument points at an attribute value inside the request buffer
 * and runs up to the next '\n'.  The key is laid out as:
 *   1. the client address reduced to its network part, using
 *      ipv4_network_strip_bytes and ipv4_network_mask;
 *   2. a NUL;
 *   3. the sender;
 *   4. a NUL;
 *   5. the recipient, which is not NUL-terminated.
 *
 * When the prefix cannot be applied to an address (for example an IPv6
 * address without dots), the address is used as is; a LOG_ERR message
 * is emitted when the byte strip fails.  dbkey.data is grown with
 * xrealloc() and reused between calls. */
static void
build_triplet_key(const char *ip, const char *from, const char *to)
{
    int ipv4 = 1;
    const char *endip = strchr(ip, '\n'),
	*endfrom = strchr(from, '\n'),
	*endto = strchr(to, '\n');
    size_t lenfrom = endfrom - from,
	lento = endto - to;
    size_t lenip, total;
    char *buf;
    /* Mangle the IP address so that only the required prefix is used */
    if (ipv4_network_strip_bytes > 0) {
	const char *p = endip;
	unsigned int i = ipv4_network_strip_bytes;
	/* Walk backwards over the last I dot-separated numbers. */
	while (i && --p > ip)
	    if (*p == '.')
		i--;
	/* I == 0: P is on the dot in front of the stripped numbers.
	   I == 1 is only valid when all four numbers of a dotted quad are
	   stripped: the first number has no dot before it, so the walk
	   stops at the start of the address with one number left.
	   Accepting I == 1 for fewer strip bytes would reduce every
	   address with too few dots (e.g. every IPv6 address) to an empty
	   string, merging them all into one key. */
	if (i == 0 || (i == 1 && ipv4_network_strip_bytes == 4))
	    endip = p;
	else {
	    ipv4 = 0;
	    syslog(LOG_ERR, "Could not apply network strip");
	}
    }
    lenip = endip - ip;
    total = lenip + lenfrom + lento + 2;	/* + two NUL separators */
    if (dbkey.ulen < total) {
	dbkey.data = xrealloc(dbkey.data, total);
	dbkey.ulen = total;
	dbkey.flags = DB_DBT_USERMEM;
    }
    buf = (char *)dbkey.data;
    if (ipv4 && ipv4_network_mask != 0xffU) {
	/* Mask the last remaining number of the IP address */
	char *q;
	unsigned long byte;
	const char *p = endip;
	while (--p > ip)
	    if (*p == '.')
		break;
	if (*p == '.')
	    p++;
	byte = strtoul(p, &q, 10);
	/* Require the number to run exactly up to endip: anything else
	   (e.g. "100::1") is not a dotted-quad tail. */
	if (p != q && q == endip && byte < 256U) {
	    size_t n = p - ip;
	    memcpy(buf, ip, n);
	    buf += n;
	    /* The masked value is never larger than the original, so it
	       has no more digits and fits in the space reserved for the
	       address (sprintf's NUL lands at most on the separator). */
	    n = sprintf(buf, "%lu", byte & ipv4_network_mask);
	    buf += n;
	    assert((size_t)(buf - (char *)dbkey.data) <= lenip);
	}
	else
	    ipv4 = 0;
    }
    if (! ipv4 || ipv4_network_mask == 0xffU) {
	memcpy(buf, ip, lenip);
	buf += lenip;
    }
    *buf++ = 0;
    if (debug_me)
	syslog(LOG_DEBUG, "triplet effective IP: %s", (char *)dbkey.data);
    memcpy(buf, from, lenfrom);
    buf += lenfrom;
    *buf++ = 0;
    memcpy(buf, to, lento);
    buf += lento;
    dbkey.size = buf - (char *)dbkey.data;
    assert(dbkey.size <= dbkey.ulen);
}

/* Set triplet_data.access_time to the current time; fatal() if the
 * clock cannot be read. */
static void
touch_data()
{
    triplet_data.access_time = time(NULL);
    if (triplet_data.access_time == (time_t)-1)
	fatal("time failed");
}

/* Initialise triplet_data for a triplet seen for the first time: both
 * timestamps become "now" and both counters zero. */
static void
build_data()
{
    triplet_data.block_count = 0;
    triplet_data.pass_count = 0;
    touch_data();
    triplet_data.create_time = triplet_data.access_time;
}

/**********************************************************************
 * See SMTPD_POLICY_README
 */

/* Write the whole NUL-terminated string S to file descriptor FD.
 * Short writes and EINTR are retried.  Any other write error is logged
 * and the rest of the string is dropped, instead of spinning forever. */
static void
safe_writes(int fd, const char *s)
{
    size_t len = strlen(s);
    while (len) {
	ssize_t n = write(fd, s, len);
	if (n > 0) {
	    /* Advance past what was written so a short write does not
	       resend the beginning of the string. */
	    s += n;
	    len -= n;
	}
	else if (n == 0)
	    syslog(LOG_ERR, "Retrying due to empty write");
	else if (errno == EINTR) {
	    if (debug_me)
		syslog(LOG_DEBUG, "Retrying write: %s", strerror(errno));
	}
	else {
	    syslog(LOG_ERR, "write failed: %s", strerror(errno));
	    return;
	}
    }
}

/* Return a pointer to the value of attribute NAME ("NAME=value\n") in
 * the first request of policy_request, or NULL when it is absent.  The
 * returned value runs up to the next '\n'.  The scan stops at the first
 * empty line.
 * XXX Assumes there is an empty line marking the end of request. */
static char *
find_attribute(const char *name)
{
    const size_t name_len = strlen(name);
    char *line = policy_request;

    while (*line != '\n') {
	if (!strncmp(line, name, name_len) && line[name_len] == '=')
	    return line + name_len + 1;
	line = strchr(line, '\n') + 1;
    }
    return NULL;
}

/* Return the address just past the first "\n\n" pair in [P, ENDP),
 * i.e. just past the first empty line, or NULL when there is none.
 * A buffer that starts with '\n' does not count as starting with an
 * empty line: only a '\n' that follows another '\n' ends a request. */
static char *
find_empty_line(const char *p, const char *const endp)
{
    const char *cur;
    for (cur = p; cur < endp; cur++)
	if (*cur == '\n' && endp - cur > 1 && cur[1] == '\n')
	    return (char *) (cur + 2);
    return NULL;
}

/* Find the end of request marker: return the address just past the
 * first empty line within the filled part of policy_request, or NULL
 * when no complete request has been read yet. */
static char *
find_eor()
{
    const char *const filled_end = policy_request + policy_request_fill;
    return find_empty_line(policy_request, filled_end);
}

/* Forget the oldest policy request.  Everything up to and including its
 * terminating empty line is discarded, and any following bytes are
 * moved to the start of the buffer.  When no complete request is
 * buffered, the whole buffer is emptied. */
static void
forget_policy_request()
{
    const char *eor = find_eor();
    size_t consumed;

    if (!eor) {
	policy_request_fill = 0;
	return;
    }
    consumed = eor - policy_request;
    policy_request_fill -= consumed;
    memmove(policy_request, eor, policy_request_fill);
}

/* Read in a new SMTPD access policy request.
 *
 * Discards the previous request, then reads from file descriptor IN
 * until policy_request holds a complete request (one ending with an
 * empty line).  The buffer starts at BUFSIZ bytes and doubles whenever
 * it is full.  EINTR is retried.
 *
 * Returns policy_request, or NULL on end of input or on any other read
 * error (which is logged). */
static const char *
read_policy_request(int in)
{
    forget_policy_request();
    while (! find_eor()) {
	size_t wanted;
	ssize_t n;
	/* Make sure there is some room to read data */
	if (policy_request_fill == policy_request_size) {
	    if (policy_request_size)
		policy_request_size *= 2;
	    else
		policy_request_size = BUFSIZ;
	    if (debug_me)
		syslog(LOG_DEBUG,
		       "allocate %lu bytes for request buffer",
		       (unsigned long) policy_request_size);
	    policy_request = xrealloc(policy_request, policy_request_size);
	}
	wanted = policy_request_size - policy_request_fill;
	n = read(in, policy_request + policy_request_fill, wanted);
	if (n > 0)
	    policy_request_fill += n;
	else if (n == 0)
	    return NULL;
	else if (errno != EINTR) {
	    /* The reason is in errno (n is just -1).  A persistent error
	       would otherwise be retried forever. */
	    syslog(LOG_ERR, "read_policy_request failed: %s",
		   strerror(errno));
	    return NULL;
	}
    }
    return policy_request;
}

static void
get_grey_data()
{
    int rc;
    rc = db->get(db, NULL, &dbkey, &dbdata, 0);
    if (rc == DB_NOTFOUND)
	build_data();
    else if (rc) {
	log_db_error("get failed", rc);
	fatal("Exiting");
    }
    else
	touch_data();
}

/* Store triplet_data (through dbdata) under the key in dbkey.  A
 * failure is logged and otherwise ignored. */
static void
put_grey_data()
{
    int rc = db->put(db, NULL, &dbkey, &dbdata, 0);
    if (rc != 0)
	log_db_error("put", rc);
}

/* Record the expiry of the triplet in triplet_data in the statistics
 * database.  Only triplets that never let mail through are counted,
 * under the STAT_BLOCKED key; nothing is recorded for the others.
 * Database errors are logged and the update is skipped. */
static void
stats_expire()
{
    int rc;
    DBT key;
    DBT data;
    unsigned long count;

    if (triplet_data.pass_count)
	return;

    memset(&key, 0, sizeof key);
    memset(&data, 0, sizeof data);
    key.data = STAT_BLOCKED;
    key.size = strlen(STAT_BLOCKED);
    data.data = &count;
    data.ulen = sizeof count;
    data.flags = DB_DBT_USERMEM;
    rc = statsdb->get(statsdb, NULL, &key, &data, 0);
    if (rc == 0 && data.size == sizeof count)
	count++;
    else if (rc == 0 || rc == DB_NOTFOUND)
	/* First expiry, or a stored value we cannot interpret. */
	count = 1;
    else {
	syslog(LOG_ERR, "BDB-%d: stats_expire block get: %s",
	       rc, db_strerror(rc));
	return;
    }
    /* Do not rely on get() having set data.size (it was 0 after the
       memset): always store the full counter. */
    data.size = sizeof count;
    rc = statsdb->put(statsdb, NULL, &key, &data, 0);
    if (rc)
	syslog(LOG_ERR, "BDB-%d: stats_expire block put: %s",
	       rc, db_strerror(rc));
}

static int
triplet_expired_p(const struct triplet_data *data)
{
    unsigned int age = triplet_data.access_time - triplet_data.create_time;
    return triplet_data.pass_count == 0 && age > bloc_max_idle
	|| triplet_data.pass_count != 0 && age > pass_max_idle;
}

/* Decide the policy action for the triplet in dbkey, print it on
 * stdout and store the updated record.  main() prints the empty line
 * that terminates the reply.
 *
 * An expired record starts over as a new triplet.  While the triplet is
 * younger than greylist_delay the mail is deferred.  The first mail
 * accepted after that gets an X-Greyfix header; later ones get DUNNO. */
static void
process_smtp_rcpt()
{
    unsigned long delay;
    get_grey_data();
    /* Expire records */
    if (triplet_expired_p(&triplet_data)) {
	if (debug_me)
	    syslog(LOG_DEBUG, "expired record");
	stats_expire();		/* reads pass_count, so before the reset */
	/* Start over as a brand new triplet.  With the old counters kept,
	   a re-greylisted triplet that once passed mail would skip its
	   X-Greyfix header and expire after pass_max_idle instead of
	   bloc_max_idle. */
	triplet_data.create_time = triplet_data.access_time;
	triplet_data.block_count = 0;
	triplet_data.pass_count = 0;
    }
    delay = triplet_data.access_time - triplet_data.create_time;
    /* Block inbound mail that is from a previously unknown (ip, from, to) triplet */
    if (delay < greylist_delay) {
	delay = greylist_delay - delay;
	triplet_data.block_count++;
	printf("action=DEFER_IF_PERMIT Greylisted by " PACKAGE_STRING ", "
	       "try again in %lu second%s."
	       "  See http://projects.puremagic.com/greylisting/ for more information.\n",
	       delay, delay == 1 ? "" : "s");
    }
    else if (triplet_data.pass_count++)
	puts("action=DUNNO");
    else
	/* First mail let through for this triplet. */
	printf("action=PREPEND X-Greyfix: Greylisted by " PACKAGE_STRING
	       " for %lu second%s\n",
	       delay, delay == 1 ? "" : "s");
    put_grey_data();
}

static void
signal_handler(int signal)
{
    cleanup();
    syslog(LOG_NOTICE, "Received signal %d", signal);
    kill(getpid(), signal);
    exit(-1);
}

/* Return 1 when command-line argument OPT equals either its short
 * spelling SHORT_OPT or its long spelling LONG_OPT, 0 otherwise. */
static int
optp(const char *opt, const char *short_opt, const char *long_opt)
{
    if (strcmp(opt, short_opt) == 0)
	return 1;
    return !strcmp(opt, long_opt);
}

int
main(int argc, const char **argv)
{
    char *p;
    unsigned int i;
    int rc;
    unsigned long network_prefix = 32;
    unsigned int suffix;
    progname = strrchr(argv[0], '/');
    if (progname)
	progname++;
    else
	progname = argv[0];
    openlog(progname, LOG_PID, LOG_MAIL);
    for (i = 1; i < argc; i++) {
	if (optp(argv[i], "-d", "--debug"))
	    debug_me = 1;
	else if (optp(argv[i], "-v", "--verbose"))
	    opt_verbose++;
	else if (optp(argv[i], "-b", "--bloc-max-idle")) {
	    if (++i >= argc)
		fatal("Missing argument to --bloc-max-idle");
	    bloc_max_idle = strtoul(argv[i], &p, 10);
	    if (*p)
		fatal("Invalid argument to --bloc-max-idle.  "
		      "Integer value expected");
	}
	else if (optp(argv[i], "-g", "--greylist-delay")) {
	    if (++i >= argc)
		fatal("Missing argument to --greylist-delay");
	    greylist_delay = strtoul(argv[i], &p, 10);
	    if (*p)
		fatal("Invalid argument to --greylist-delay.  "
		      "Integer value expected");
	}
	else if (optp(argv[i], "-h", "--home")) {
	    if (++i >= argc)
		fatal("Missing argument to --home");
	    db_home = argv[i];
	}
	else if (optp(argv[i], "-/", "--network-prefix")) {
	    if (++i >= argc)
		fatal("Missing argument to --network-prefix");
	    network_prefix = strtoul(argv[i], &p, 10);
	    if (*p || network_prefix > 32U)
		fatal("Invalid argument to --network-prefix.  "
		      "Integer value between 0 and 32 expected");
	}
	else if (optp(argv[i], "-p", "--pass-max-idle")) {
	    if (++i >= argc)
		fatal("Missing argument to --pass-max-idle");
	    pass_max_idle = strtoul(argv[i], &p, 10);
	    if (*p)
		fatal("Invalid argument to --pass-max-idle.  "
		      "Integer value expected");
	}
	else {
	    fprintf(stderr, "Unknown option \"%s\"\n", argv[i]);
	    exit(EXIT_FAILURE);
	}
    }
    suffix = 32 - network_prefix;
    ipv4_network_strip_bytes = suffix >> 3;
    ipv4_network_mask = 0xffU & ~0U << (suffix & 0x7U);

    if (opt_verbose)
        syslog(LOG_NOTICE, "daemon started");
#ifdef SIGHUP
    signal(SIGHUP, signal_handler);
#endif
#ifdef SIGINT
    signal(SIGINT, signal_handler);
#endif
#ifdef SIGQUIT
    signal(SIGQUIT, signal_handler);
#endif
#ifdef SIGILL
    signal(SIGILL, signal_handler);
#endif
#ifdef SIGABRT
    signal(SIGABRT, signal_handler);
#endif
#ifdef SIGSEGV
    signal(SIGSEGV, signal_handler);
#endif
#ifdef SIGALRM
    signal(SIGALRM, signal_handler);
#endif
#ifdef SIGTERM
    signal(SIGTERM, signal_handler);
#endif

    rc = initialize();
    if (rc) {
	errno = rc;
	fatal("initialization failure");
    }
    while (read_policy_request(0)) {
	const char *protocol = 0, *state = 0, *ip, *from, *to;
	if ((protocol = find_attribute("protocol_name"))
	    && (state = find_attribute("protocol_state"))
	    && (prefixp(protocol, str_smtp) || prefixp(protocol, str_esmtp))
	    && prefixp(state, str_rcpt)
	    && (ip = find_attribute("client_address"))
	    && (from = find_attribute("sender"))
	    && (to = find_attribute("recipient"))) {
	    build_triplet_key(ip, from, to);
	    process_smtp_rcpt();
	}
	else {
	    puts("action=DUNNO");
	    if (debug_me) {
		char *p = 0, *s = 0;
		if (protocol) {
		    p = strchr(protocol, '\n');
		    *p = 0;
		}
		if (state) {
		    s = strchr(state, '\n');
		    *s = 0;
		}
		syslog(LOG_DEBUG,
		       "Ignoring protocol %s state %s",
		       protocol ? protocol : "(not defined)",
		       state ? state : "(not defined)");
		if (p)
		    *p = '\n';
		if (s)
		    *s = '\n';
	    }
	}
	putchar('\n');
	fflush(stdout);
    }
    cleanup();
    if (opt_verbose)
	syslog(LOG_NOTICE, "daemon stopped");
    closelog();
    return 0;
}
