Skip to content
Snippets Groups Projects
tcp.c 54.1 KiB
Newer Older
/* -*- mode: c; c-basic-offset: 8; -*-
 *
 * vim: noexpandtab sw=8 ts=8 sts=0:
 *
 * Copyright (C) 2004 Oracle.  All rights reserved.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with this program; if not, write to the
 * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 021110-1307, USA.
 *
 * ----
 *
 * Callers for this were originally written against a very simple synchronous
 * API.  This implementation reflects those simple callers.  Some day I'm sure
 * we'll need to move to a more robust posting/callback mechanism.
 *
 * Transmit calls pass in kernel virtual addresses and block copying this into
 * the socket's tx buffers via a usual blocking sendmsg.  They'll block waiting
 * for a failed socket to timeout.  TX callers can also pass in a pointer to an
 * 'int' which gets filled with an errno off the wire in response to the
 * message they send.
 *
 * Handlers for unsolicited messages are registered.  Each socket has a page
 * that incoming data is copied into.  First the header, then the data.
 * Handlers are called from only one thread with a reference to this per-socket
 * page.  This page is destroyed after the handler call, so it can't be
 * referenced beyond the call.  Handlers may block but are discouraged from
 * doing so.
 *
 * Any framing errors (bad magic, large payload lengths) close a connection.
 *
 * Our sock_container holds the state we associate with a socket.  Its current
 * framing state is held there as well as the refcounting we do around when it
 * is safe to tear down the socket.  The socket is only finally torn down from
 * the container when the container loses all of its references -- so as long
 * as you hold a ref on the container you can trust that the socket is valid
 * for use with kernel socket APIs.
 *
 * Connections are initiated between a pair of nodes when the node with the
 * higher node number gets a heartbeat callback which indicates that the lower
 * numbered node has started heartbeating.  The lower numbered node is passive
 * and only accepts the connection if the higher numbered node is heartbeating.
 */

#include <linux/kernel.h>
#include <linux/jiffies.h>
#include <linux/slab.h>
#include <linux/idr.h>
#include <linux/kref.h>
#include <net/tcp.h>

#include <asm/uaccess.h>

#include "heartbeat.h"
#include "tcp.h"
#include "nodemanager.h"
#define MLOG_MASK_PREFIX ML_TCP
#include "masklog.h"
#include "quorum.h"

#include "tcp_internal.h"

/* 
 * The linux network stack isn't sparse endian clean.. It has macros like
 * ntohs() which perform the endian checks and structs like sockaddr_in
 * which aren't annotated.  So __force is found here to get the build
 * clean.  When they emerge from the dark ages and annotate the code
 * we can remove these.
 */

/*
 * Format string and matching argument list used to name the node behind a
 * sock container in log output: node name, node number, then the node's
 * IPv4 address and port.  SC_NODEF_ARGS() expands 'sc' several times, so
 * pass it an expression without side effects.
 */
#define SC_NODEF_FMT "node %s (num %u) at %u.%u.%u.%u:%u"
#define SC_NODEF_ARGS(sc) sc->sc_node->nd_name, sc->sc_node->nd_num,	\
			  NIPQUAD(sc->sc_node->nd_ipv4_address),	\
			  ntohs(sc->sc_node->nd_ipv4_port)

/*
 * In the following two log macros, the whitespace after the ',' just
 * before ##args is intentional. Otherwise, gcc 2.95 will eat the
 * previous token if args expands to nothing.
 */

/*
 * msglog() logs 'fmt' under ML_MSG, prefixed with the fields of the message
 * header 'hdr' converted from big-endian to CPU byte order.  'hdr' is
 * evaluated once.
 */
#define msglog(hdr, fmt, args...) do {					\
	typeof(hdr) __hdr = (hdr);					\
	mlog(ML_MSG, "[mag %u len %u typ %u stat %d sys_stat %d "	\
	     "key %08x num %u] " fmt,					\
	     be16_to_cpu(__hdr->magic), be16_to_cpu(__hdr->data_len), 	\
	     be16_to_cpu(__hdr->msg_type), be32_to_cpu(__hdr->status),	\
	     be32_to_cpu(__hdr->sys_status), be32_to_cpu(__hdr->key),	\
	     be32_to_cpu(__hdr->msg_num) ,  ##args);			\
} while (0)

/*
 * sclog() logs 'fmt' under ML_SOCKET, prefixed with the sock container's
 * address, kref count, socket, node number, page and page offset.  'sc' is
 * evaluated once; it dereferences sc->sc_node, so the container must still
 * have its node attached when this is used.
 */
#define sclog(sc, fmt, args...) do {					\
	typeof(sc) __sc = (sc);						\
	mlog(ML_SOCKET, "[sc %p refs %d sock %p node %u page %p "	\
	     "pg_off %zu] " fmt, __sc,					\
	     atomic_read(&__sc->sc_kref.refcount), __sc->sc_sock,	\
	    __sc->sc_node->nd_num, __sc->sc_page, __sc->sc_page_off ,	\
	    ##args);							\
} while (0)

/* registered message handlers live in this rb-tree; the rwlock is taken
 * for write when handlers are added or removed and for read on lookup */
static DEFINE_RWLOCK(o2net_handler_lock);
static struct rb_root o2net_handler_tree = RB_ROOT;

/* per-node connection state, indexed by node number */
static struct o2net_node o2net_nodes[O2NM_MAX_NODES];

/* XXX someday we'll need better accounting */
static struct socket *o2net_listen_sock = NULL;

/*
 * listen work is only queued by the listening socket callbacks on the
 * o2net_wq.  teardown detaches the callbacks before destroying the workqueue.
 * quorum work is queued as sock containers are shutdown.. stop_listening
 * tears down all the node's sock containers, preventing future shutdowns
 * and queued quorum work, before canceling delayed quorum work and
 * destroying the work queue.
 */
static struct workqueue_struct *o2net_wq;
static struct work_struct o2net_listen_work;

static struct o2hb_callback_func o2net_hb_up, o2net_hb_down;
#define O2NET_HB_PRI 0x1

static struct o2net_handshake *o2net_hand;
static struct o2net_msg *o2net_keep_req, *o2net_keep_resp;

/* maps each o2net system error code to the errno handed back to callers;
 * consumed by o2net_sys_err_to_errno(), which insists that every code
 * other than O2NET_ERR_NONE has a non-zero entry */
static int o2net_sys_err_translations[O2NET_ERR_MAX] =
		{[O2NET_ERR_NONE]	= 0,
		 [O2NET_ERR_NO_HNDLR]	= -ENOPROTOOPT,
		 [O2NET_ERR_OVERFLOW]	= -EOVERFLOW,
		 [O2NET_ERR_DIED]	= -EHOSTDOWN,};

/* can't quite avoid *all* internal declarations :/
 * these are referenced (as work functions, timer callbacks and socket
 * callbacks) before their definitions appear in this file */
static void o2net_sc_connect_completed(struct work_struct *work);
static void o2net_rx_until_empty(struct work_struct *work);
static void o2net_shutdown_sc(struct work_struct *work);
static void o2net_listen_data_ready(struct sock *sk, int bytes);
static void o2net_sc_send_keep_req(struct work_struct *work);
static void o2net_idle_timer(unsigned long data);
static void o2net_sc_postpone_idle(struct o2net_sock_container *sc);
static void o2net_sc_reset_idle_timer(struct o2net_sock_container *sc);

/*
 * FIXME: These should use to_o2nm_cluster_from_node(), but we end up
 * losing our parent link to the cluster during shutdown. This can be
 * solved by adding a pre-removal callback to configfs, or passing
 * around the cluster with the node. -jeffm
 */
/* Reconnect delay in milliseconds, read from the single cluster.  The
 * 'node' argument is currently not consulted. */
static inline int o2net_reconnect_delay(struct o2nm_node *node)
{
	int delay_ms = o2nm_single_cluster->cl_reconnect_delay_ms;

	return delay_ms;
}

/* Keepalive delay in milliseconds, read from the single cluster.  The
 * 'node' argument is currently not consulted. */
static inline int o2net_keepalive_delay(struct o2nm_node *node)
{
	int delay_ms = o2nm_single_cluster->cl_keepalive_delay_ms;

	return delay_ms;
}

/* Idle timeout in milliseconds, read from the single cluster.  The
 * 'node' argument is currently not consulted. */
static inline int o2net_idle_timeout(struct o2nm_node *node)
{
	int timeout_ms = o2nm_single_cluster->cl_idle_timeout_ms;

	return timeout_ms;
}

/*
 * Translate an o2net system error code into the errno stored for it in
 * o2net_sys_err_translations[].  Out-of-range codes, and codes other than
 * O2NET_ERR_NONE whose table entry is 0, trip a BUG.
 */
static inline int o2net_sys_err_to_errno(enum o2net_system_error err)
{
	int translated;

	BUG_ON(err >= O2NET_ERR_MAX);

	translated = o2net_sys_err_translations[err];

	/* a zero entry for a real error means the table is missing a slot */
	BUG_ON(translated == 0 && err != O2NET_ERR_NONE);

	return translated;
}

/* Return the o2net_node slot for 'node_num'; BUGs if the number is outside
 * the o2net_nodes[] array. */
static struct o2net_node * o2net_nn_from_num(u8 node_num)
{
	BUG_ON(node_num >= ARRAY_SIZE(o2net_nodes));

	return o2net_nodes + node_num;
}

static u8 o2net_num_from_nn(struct o2net_node *nn)
{
	BUG_ON(nn == NULL);
	return nn - o2net_nodes;
}

/* ------------------------------------------------------------ */

/*
 * Register a status wait with a node so a later completion can find it.
 *
 * Allocates an id for 'nsw' in the node's status idr (stored in
 * nsw->ns_id) and links it onto the node's status list, retrying while the
 * idr asks for more preallocated memory.  On success the wait queue and the
 * status fields of 'nsw' are initialised.
 *
 * Returns 0 on success, -EAGAIN if idr preallocation fails, or the error
 * from idr_get_new().
 */
static int o2net_prep_nsw(struct o2net_node *nn, struct o2net_status_wait *nsw)
{
	int err;

	for (;;) {
		if (!idr_pre_get(&nn->nn_status_idr, GFP_ATOMIC))
			return -EAGAIN;

		spin_lock(&nn->nn_lock);
		err = idr_get_new(&nn->nn_status_idr, nsw, &nsw->ns_id);
		if (!err)
			list_add_tail(&nsw->ns_node_item,
				      &nn->nn_status_list);
		spin_unlock(&nn->nn_lock);

		/* -EAGAIN means the preallocation was consumed; go again */
		if (err != -EAGAIN)
			break;
	}

	if (err)
		return err;

	init_waitqueue_head(&nsw->ns_wq);
	nsw->ns_sys_status = O2NET_ERR_NONE;
	nsw->ns_status = 0;

	return 0;
}

/*
 * Complete a status wait with the given system and user status and wake
 * its waiter.  The caller holds nn->nn_lock.  A wait that is no longer on
 * the node's list has already been completed and is left untouched.
 */
static void o2net_complete_nsw_locked(struct o2net_node *nn,
				      struct o2net_status_wait *nsw,
				      enum o2net_system_error sys_status,
				      s32 status)
{
	assert_spin_locked(&nn->nn_lock);

	if (list_empty(&nsw->ns_node_item))
		return;

	list_del_init(&nsw->ns_node_item);
	nsw->ns_sys_status = sys_status;
	nsw->ns_status = status;
	idr_remove(&nn->nn_status_idr, nsw->ns_id);
	wake_up(&nsw->ns_wq);
}

/*
 * Complete a status wait under nn->nn_lock.  When 'nsw' is NULL the wait is
 * looked up by 'id' in the node's status idr; ids above INT_MAX and ids with
 * no registered wait are silently ignored.
 */
static void o2net_complete_nsw(struct o2net_node *nn,
			       struct o2net_status_wait *nsw,
			       u64 id, enum o2net_system_error sys_status,
			       s32 status)
{
	spin_lock(&nn->nn_lock);

	if (nsw == NULL && id <= INT_MAX)
		nsw = idr_find(&nn->nn_status_idr, id);

	if (nsw != NULL)
		o2net_complete_nsw_locked(nn, nsw, sys_status, status);

	spin_unlock(&nn->nn_lock);
}

static void o2net_complete_nodes_nsw(struct o2net_node *nn)
{
	struct list_head *iter, *tmp;
	unsigned int num_kills = 0;
	struct o2net_status_wait *nsw;

	assert_spin_locked(&nn->nn_lock);

	list_for_each_safe(iter, tmp, &nn->nn_status_list) {
		nsw = list_entry(iter, struct o2net_status_wait, ns_node_item);
		o2net_complete_nsw_locked(nn, nsw, O2NET_ERR_DIED, 0);
		num_kills++;
	}

	mlog(0, "completed %d messages for node %u\n", num_kills,
	     o2net_num_from_nn(nn));
}

/* Non-zero once 'nsw' has been taken off the node's status list, i.e. once
 * it has been completed.  Sampled under nn->nn_lock. */
static int o2net_nsw_completed(struct o2net_node *nn,
			       struct o2net_status_wait *nsw)
{
	int done;

	spin_lock(&nn->nn_lock);
	done = list_empty(&nsw->ns_node_item);
	spin_unlock(&nn->nn_lock);

	return done;
}

/* ------------------------------------------------------------ */

/*
 * kref release callback for a sock container, run when its last reference
 * is dropped.  Releases the socket (if one was attached), drops the node
 * reference, frees the container's page and finally the container itself.
 * The idle timer must no longer be pending at this point.
 */
static void sc_kref_release(struct kref *kref)
{
	struct o2net_sock_container *sc = container_of(kref,
					struct o2net_sock_container, sc_kref);
	BUG_ON(timer_pending(&sc->sc_idle_timeout));

	sclog(sc, "releasing\n");

	if (sc->sc_sock) {
		sock_release(sc->sc_sock);
		sc->sc_sock = NULL;
	}

	o2nm_node_put(sc->sc_node);
	sc->sc_node = NULL;

	/* sc_alloc() attaches a page from alloc_page() to every container;
	 * this release path never gave it back, leaking one page per
	 * container */
	if (sc->sc_page) {
		__free_page(sc->sc_page);
		sc->sc_page = NULL;
	}

	kfree(sc);
}

/* Drop a reference on a sock container; the last put runs
 * sc_kref_release().  The log line is emitted before the put because the
 * container may be freed by it. */
static void sc_put(struct o2net_sock_container *sc)
{
	sclog(sc, "put\n");
	kref_put(&sc->sc_kref, sc_kref_release);
}
/* Take an additional reference on a sock container (logged via sclog). */
static void sc_get(struct o2net_sock_container *sc)
{
	sclog(sc, "get\n");
	kref_get(&sc->sc_kref);
}
/*
 * Allocate and initialise a sock container for 'node'.
 *
 * The container starts with one kref reference, holds a reference on the
 * node, owns a freshly allocated page, and has its work items and idle
 * timer initialised.  Returns NULL if either allocation fails.
 */
static struct o2net_sock_container *sc_alloc(struct o2nm_node *node)
{
	struct o2net_sock_container *sc;
	struct page *page;

	page = alloc_page(GFP_NOFS);
	sc = kcalloc(1, sizeof(*sc), GFP_NOFS);
	if (sc == NULL || page == NULL) {
		/* undo whichever half of the pair did succeed */
		if (page)
			__free_page(page);
		kfree(sc);
		return NULL;
	}

	kref_init(&sc->sc_kref);
	o2nm_node_get(node);
	sc->sc_node = node;

	INIT_WORK(&sc->sc_connect_work, o2net_sc_connect_completed);
	INIT_WORK(&sc->sc_rx_work, o2net_rx_until_empty);
	INIT_WORK(&sc->sc_shutdown_work, o2net_shutdown_sc);
	INIT_DELAYED_WORK(&sc->sc_keepalive_work, o2net_sc_send_keep_req);

	init_timer(&sc->sc_idle_timeout);
	sc->sc_idle_timeout.function = o2net_idle_timer;
	sc->sc_idle_timeout.data = (unsigned long)sc;

	sclog(sc, "alloced\n");

	sc->sc_page = page;

	return sc;
}

/* ------------------------------------------------------------ */

/* Queue 'work' on o2net_wq with a container reference held on its behalf.
 * If the work was already queued the extra reference is dropped again. */
static void o2net_sc_queue_work(struct o2net_sock_container *sc,
				struct work_struct *work)
{
	sc_get(sc);

	if (queue_work(o2net_wq, work))
		return;

	sc_put(sc);
}
/* Delayed-work variant of o2net_sc_queue_work(): a container reference
 * rides along with the queued work and is dropped straight away if the
 * work was already pending. */
static void o2net_sc_queue_delayed_work(struct o2net_sock_container *sc,
					struct delayed_work *work,
					int delay)
{
	sc_get(sc);

	if (queue_delayed_work(o2net_wq, work, delay))
		return;

	sc_put(sc);
}
/* Cancel delayed work; when a pending item really was cancelled, drop the
 * container reference that was taken when it was queued. */
static void o2net_sc_cancel_delayed_work(struct o2net_sock_container *sc,
					 struct delayed_work *work)
{
	if (!cancel_delayed_work(work))
		return;

	sc_put(sc);
}

/* number of nodes that currently have a sock container attached; adjusted
 * by o2net_set_nn_state() whenever an nn gains or loses its sc */
static atomic_t o2net_connected_peers = ATOMIC_INIT(0);

/* Return a snapshot of the connected-peer counter. */
int o2net_num_connected_peers(void)
{
	return atomic_read(&o2net_connected_peers);
}

/*
 * Record a new connection state for a node.  Must be called with
 * nn->nn_lock held.
 *
 * @nn:    node whose state is being updated
 * @sc:    sock container to attach, or NULL to detach the current one
 * @valid: non-zero when the connection may carry messages; requires @sc
 * @err:   persistent error to record; must be 0 when @valid is set
 *
 * Besides storing the new state this adjusts the connected-peer count,
 * wakes transmitters waiting on nn_sc_wq, notifies the quorum code,
 * completes pending status waits when a valid connection is lost, queues a
 * connect attempt while the node is not valid, and moves the nn's reference
 * from the old container to the new one.
 */
static void o2net_set_nn_state(struct o2net_node *nn,
			       struct o2net_sock_container *sc,
			       unsigned valid, int err)
{
	int was_valid = nn->nn_sc_valid;
	int was_err = nn->nn_persistent_error;
	struct o2net_sock_container *old_sc = nn->nn_sc;

	assert_spin_locked(&nn->nn_lock);

	if (old_sc && !sc)
		atomic_dec(&o2net_connected_peers);
	else if (!old_sc && sc)
		atomic_inc(&o2net_connected_peers);

	/* the node num comparison and single connect/accept path should stop
	 * a non-null sc from being overwritten with another */
	BUG_ON(sc && nn->nn_sc && nn->nn_sc != sc);
	mlog_bug_on_msg(err && valid, "err %d valid %u\n", err, valid);
	mlog_bug_on_msg(valid && !sc, "valid %u sc %p\n", valid, sc);

	/* we won't reconnect after our valid conn goes away for
	 * this hb iteration.. here so it shows up in the logs */
	if (was_valid && !valid && err == 0)
		err = -ENOTCONN;

	mlog(ML_CONN, "node %u sc: %p -> %p, valid %u -> %u, err %d -> %d\n",
	     o2net_num_from_nn(nn), nn->nn_sc, sc, nn->nn_sc_valid, valid,
	     nn->nn_persistent_error, err);

	nn->nn_sc = sc;
	nn->nn_sc_valid = valid ? 1 : 0;
	nn->nn_persistent_error = err;

	/* mirrors o2net_tx_can_proceed() */
	if (nn->nn_persistent_error || nn->nn_sc_valid)
		wake_up(&nn->nn_sc_wq);

	if (!was_err && nn->nn_persistent_error) {
		o2quo_conn_err(o2net_num_from_nn(nn));
		queue_delayed_work(o2net_wq, &nn->nn_still_up,
				   msecs_to_jiffies(O2NET_QUORUM_DELAY_MS));
	}

	if (was_valid && !valid) {
		printk(KERN_INFO "o2net: no longer connected to "
		       SC_NODEF_FMT "\n", SC_NODEF_ARGS(old_sc));
		o2net_complete_nodes_nsw(nn);
	}

	if (!was_valid && valid) {
		o2quo_conn_up(o2net_num_from_nn(nn));
		/* this is a bit of a hack.  we only try reconnecting
		 * when heartbeating starts until we get a connection.
		 * if that connection then dies we don't try reconnecting.
		 * the only way to start connecting again is to down
		 * heartbeat and bring it back up. */
		cancel_delayed_work(&nn->nn_connect_expired);
		printk(KERN_INFO "o2net: %s " SC_NODEF_FMT "\n",
		       o2nm_this_node() > sc->sc_node->nd_num ?
		       		"connected to" : "accepted connection from",
		       SC_NODEF_ARGS(sc));
	}

	/* trigger the connecting worker func as long as we're not valid,
	 * it will back off if it shouldn't connect.  This can be called
	 * from node config teardown and so needs to be careful about
	 * the work queue actually being up. */
	if (!valid && o2net_wq) {
		/* sc is NULL whenever a container is being detached (see
		 * o2net_ensure_shutdown()), so it must not be dereferenced
		 * to find the node here.  o2net_reconnect_delay() does not
		 * look at its argument, so NULL is passed instead. */
		unsigned long reconnect =
			msecs_to_jiffies(o2net_reconnect_delay(NULL));
		unsigned long delay;

		/* delay if we're within a RECONNECT_DELAY of the
		 * last attempt.  the unsigned subtraction wraps to a huge
		 * value once that window has already passed, which the
		 * comparison below turns into "connect right away". */
		delay = (nn->nn_last_connect_attempt + reconnect) - jiffies;
		if (delay > reconnect)
			delay = 0;
		mlog(ML_CONN, "queueing conn attempt in %lu jiffies\n", delay);
		queue_delayed_work(o2net_wq, &nn->nn_connect_work, delay);
	}

	/* keep track of the nn's sc ref for the caller */
	if ((old_sc == NULL) && sc)
		sc_get(sc);
	if (old_sc && (old_sc != sc)) {
		o2net_sc_queue_work(old_sc, &old_sc->sc_shutdown_work);
		sc_put(old_sc);
	}
}

/* see o2net_register_callbacks() */
static void o2net_data_ready(struct sock *sk, int bytes)
{
	void (*ready)(struct sock *sk, int bytes);

	read_lock(&sk->sk_callback_lock);
	if (sk->sk_user_data) {
		struct o2net_sock_container *sc = sk->sk_user_data;
		sclog(sc, "data_ready hit\n");
		do_gettimeofday(&sc->sc_tv_data_ready);
		o2net_sc_queue_work(sc, &sc->sc_rx_work);
		ready = sc->sc_data_ready;
	} else {
		ready = sk->sk_data_ready;
	}
	read_unlock(&sk->sk_callback_lock);

	ready(sk, bytes);
}

/*
 * sk_state_change replacement installed by o2net_register_callbacks().
 *
 * With a sock container attached: an established socket queues the
 * container's connect work, sockets still connecting are ignored, and any
 * other state queues shutdown work.  The previous state_change callback
 * (the saved one, or the socket's current one if the container is already
 * detached) is always chained to after sk_callback_lock is released.
 */
static void o2net_state_change(struct sock *sk)
{
	void (*chained)(struct sock *sk);
	struct o2net_sock_container *sc;

	read_lock(&sk->sk_callback_lock);
	sc = sk->sk_user_data;
	if (sc != NULL) {
		sclog(sc, "state_change to %d\n", sk->sk_state);

		chained = sc->sc_state_change;

		switch (sk->sk_state) {
		case TCP_SYN_SENT:
		case TCP_SYN_RECV:
			/* ignore connecting sockets as they make progress */
			break;
		case TCP_ESTABLISHED:
			o2net_sc_queue_work(sc, &sc->sc_connect_work);
			break;
		default:
			o2net_sc_queue_work(sc, &sc->sc_shutdown_work);
			break;
		}
	} else {
		chained = sk->sk_state_change;
	}
	read_unlock(&sk->sk_callback_lock);

	chained(sk);
}

/*
 * we register callbacks so we can queue work on events before calling
 * the original callbacks.  our callbacks are careful to test user_data
 * to discover when they've raced with o2net_unregister_callbacks().
 */
static void o2net_register_callbacks(struct sock *sk,
				     struct o2net_sock_container *sc)
{
	write_lock_bh(&sk->sk_callback_lock);

	/* accepted sockets inherit the old listen socket data ready */
	/* NOTE(review): in that case sk_user_data is treated as the saved
	 * original data_ready pointer -- presumably stashed there by the
	 * listening socket setup; confirm against that code */
	if (sk->sk_data_ready == o2net_listen_data_ready) {
		sk->sk_data_ready = sk->sk_user_data;
		sk->sk_user_data = NULL;
	}

	BUG_ON(sk->sk_user_data != NULL);
	sk->sk_user_data = sc;
	/* sk_user_data now points at the sc, so it holds a reference;
	 * o2net_shutdown_sc() drops it after unregistering */
	sc_get(sc);

	/* remember the callbacks being replaced so ours can chain to them
	 * and o2net_unregister_callbacks() can restore them */
	sc->sc_data_ready = sk->sk_data_ready;
	sc->sc_state_change = sk->sk_state_change;
	sk->sk_data_ready = o2net_data_ready;
	sk->sk_state_change = o2net_state_change;

	write_unlock_bh(&sk->sk_callback_lock);
}

/*
 * Detach 'sc' from the socket and restore the callbacks saved at
 * registration.  Returns 1 if this call performed the detach and 0 if the
 * socket was no longer pointing at 'sc', so callers can run their one-time
 * teardown exactly once.
 */
static int o2net_unregister_callbacks(struct sock *sk,
			           struct o2net_sock_container *sc)
{
	int detached;

	write_lock_bh(&sk->sk_callback_lock);
	detached = (sk->sk_user_data == sc);
	if (detached) {
		sk->sk_user_data = NULL;
		sk->sk_data_ready = sc->sc_data_ready;
		sk->sk_state_change = sc->sc_state_change;
	}
	write_unlock_bh(&sk->sk_callback_lock);

	return detached;
}

/*
 * this is a little helper that is called by callers who have seen a problem
 * with an sc and want to detach it from the nn if someone already hasn't beat
 * them to it.  if an error is given then the shutdown will be persistent
 * and pending transmits will be canceled.
 */
static void o2net_ensure_shutdown(struct o2net_node *nn,
			           struct o2net_sock_container *sc,
				   int err)
{
	spin_lock(&nn->nn_lock);
	/* only detach if 'sc' is still the nn's current container */
	if (nn->nn_sc == sc)
		o2net_set_nn_state(nn, NULL, 0, err);
	spin_unlock(&nn->nn_lock);
}

/*
 * This work queue function performs the blocking parts of socket shutdown.  A
 * few paths lead here.  set_nn_state will trigger this callback if it sees an
 * sc detached from the nn.  state_change will also trigger this callback
 * directly when it sees errors.  In that case we need to call set_nn_state
 * ourselves as state_change couldn't get the nn_lock and call set_nn_state
 * itself.
 */
static void o2net_shutdown_sc(struct work_struct *work)
	struct o2net_sock_container *sc =
		container_of(work, struct o2net_sock_container,
			     sc_shutdown_work);
608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000
	struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);

	sclog(sc, "shutting down\n");

	/* drop the callbacks ref and call shutdown only once */
	if (o2net_unregister_callbacks(sc->sc_sock->sk, sc)) {
		/* we shouldn't flush as we're in the thread, the
		 * races with pending sc work structs are harmless */
		del_timer_sync(&sc->sc_idle_timeout);
		o2net_sc_cancel_delayed_work(sc, &sc->sc_keepalive_work);
		sc_put(sc);
		sc->sc_sock->ops->shutdown(sc->sc_sock,
					   RCV_SHUTDOWN|SEND_SHUTDOWN);
	}

	/* not fatal so failed connects before the other guy has our
	 * heartbeat can be retried */
	o2net_ensure_shutdown(nn, sc, 0);
	sc_put(sc);
}

/* ------------------------------------------------------------ */

/*
 * Ordering function for the handler tree: compares a handler against a
 * (key, msg_type) pair byte-wise, key first and message type as the
 * tie-breaker.  Returns <0, 0 or >0 in the manner of memcmp().
 */
static int o2net_handler_cmp(struct o2net_msg_handler *nmh, u32 msg_type,
			     u32 key)
{
	int diff = memcmp(&nmh->nh_key, &key, sizeof(key));

	if (diff != 0)
		return diff;

	return memcmp(&nmh->nh_msg_type, &msg_type, sizeof(msg_type));
}

/*
 * Search the handler tree for (msg_type, key).
 *
 * Returns the matching handler or NULL.  When non-NULL, *ret_p receives the
 * link slot where the search stopped and *ret_parent the last node visited,
 * which is what rb_link_node() needs to insert a new handler after a miss.
 * Callers are expected to hold o2net_handler_lock.
 */
static struct o2net_msg_handler *
o2net_handler_tree_lookup(u32 msg_type, u32 key, struct rb_node ***ret_p,
			  struct rb_node **ret_parent)
{
	struct rb_node **link = &o2net_handler_tree.rb_node;
	struct rb_node *parent = NULL;
	struct o2net_msg_handler *found = NULL;

	while (*link != NULL) {
		struct o2net_msg_handler *cur;
		int cmp;

		parent = *link;
		cur = rb_entry(parent, struct o2net_msg_handler, nh_node);
		cmp = o2net_handler_cmp(cur, msg_type, key);
		if (cmp == 0) {
			found = cur;
			break;
		}

		link = (cmp < 0) ? &parent->rb_left : &parent->rb_right;
	}

	if (ret_p != NULL)
		*ret_p = link;
	if (ret_parent != NULL)
		*ret_parent = parent;

	return found;
}

/* kref release callback: frees the handler that embeds 'kref'. */
static void o2net_handler_kref_release(struct kref *kref)
{
	kfree(container_of(kref, struct o2net_msg_handler, nh_kref));
}

/* Drop a reference on a message handler; the last put frees it through
 * o2net_handler_kref_release(). */
static void o2net_handler_put(struct o2net_msg_handler *nmh)
{
	kref_put(&nmh->nh_kref, o2net_handler_kref_release);
}

/* max_len is protection for the handler func.  incoming messages won't
 * be given to the handler if their payload is longer than the max.
 *
 * Registers 'func' (with its private 'data') for the (msg_type, key) pair
 * and links the new handler onto 'unreg_list' so it can later be removed by
 * o2net_unregister_handler_list().
 *
 * Returns 0 on success, -EINVAL for a zero msg_type, a NULL func or a
 * max_len above O2NET_MAX_PAYLOAD_BYTES, -ENOMEM if allocation fails and
 * -EEXIST if the pair already has a handler. */
int o2net_register_handler(u32 msg_type, u32 key, u32 max_len,
			   o2net_msg_handler_func *func, void *data,
			   struct list_head *unreg_list)
{
	struct o2net_msg_handler *nmh;
	struct rb_node **p, *parent;
	int ret = 0;

	if (max_len > O2NET_MAX_PAYLOAD_BYTES) {
		mlog(0, "max_len for message handler out of range: %u\n",
			max_len);
		return -EINVAL;
	}

	if (!msg_type) {
		mlog(0, "no message type provided: %u, %p\n", msg_type, func);
		return -EINVAL;
	}

	if (!func) {
		mlog(0, "no message handler provided: %u, %p\n",
		       msg_type, func);
		return -EINVAL;
	}

	nmh = kcalloc(1, sizeof(struct o2net_msg_handler), GFP_NOFS);
	if (nmh == NULL)
		return -ENOMEM;

	nmh->nh_func = func;
	nmh->nh_func_data = data;
	nmh->nh_msg_type = msg_type;
	nmh->nh_max_len = max_len;
	nmh->nh_key = key;
	/* the tree and list get this ref.. they're both removed in
	 * unregister when this ref is dropped */
	kref_init(&nmh->nh_kref);
	INIT_LIST_HEAD(&nmh->nh_unregister_item);

	write_lock(&o2net_handler_lock);
	if (o2net_handler_tree_lookup(msg_type, key, &p, &parent)) {
		ret = -EEXIST;
	} else {
		rb_link_node(&nmh->nh_node, parent, p);
		rb_insert_color(&nmh->nh_node, &o2net_handler_tree);
		list_add_tail(&nmh->nh_unregister_item, unreg_list);

		mlog(ML_TCP, "registered handler func %p type %u key %08x\n",
		     func, msg_type, key);
		/* we've had some trouble with handlers seemingly vanishing. */
		mlog_bug_on_msg(o2net_handler_tree_lookup(msg_type, key, &p,
							  &parent) == NULL,
			        "couldn't find handler we *just* registerd "
				"for type %u key %08x\n", msg_type, key);
	}
	write_unlock(&o2net_handler_lock);

	/* a handler that never made it into the tree is ours to free */
	if (ret)
		kfree(nmh);

	return ret;
}
EXPORT_SYMBOL_GPL(o2net_register_handler);

void o2net_unregister_handler_list(struct list_head *list)
{
	struct list_head *pos, *n;
	struct o2net_msg_handler *nmh;

	write_lock(&o2net_handler_lock);
	list_for_each_safe(pos, n, list) {
		nmh = list_entry(pos, struct o2net_msg_handler,
				 nh_unregister_item);
		mlog(ML_TCP, "unregistering handler func %p type %u key %08x\n",
		     nmh->nh_func, nmh->nh_msg_type, nmh->nh_key);
		rb_erase(&nmh->nh_node, &o2net_handler_tree);
		list_del_init(&nmh->nh_unregister_item);
		kref_put(&nmh->nh_kref, o2net_handler_kref_release);
	}
	write_unlock(&o2net_handler_lock);
}
EXPORT_SYMBOL_GPL(o2net_unregister_handler_list);

/* Look up the handler for (msg_type, key) under the handler read lock and
 * return it with an extra reference held, or NULL if none is registered.
 * The caller releases the reference with o2net_handler_put(). */
static struct o2net_msg_handler *o2net_handler_get(u32 msg_type, u32 key)
{
	struct o2net_msg_handler *found;

	read_lock(&o2net_handler_lock);
	found = o2net_handler_tree_lookup(msg_type, key, NULL, NULL);
	if (found != NULL)
		kref_get(&found->nh_kref);
	read_unlock(&o2net_handler_lock);

	return found;
}

/* ------------------------------------------------------------ */

/*
 * Non-blocking receive of up to 'len' bytes from 'sock' into the kernel
 * buffer 'data'.  Returns whatever sock_recvmsg() returns.  The address
 * limit is switched for the duration of the call because 'data' is a kernel
 * pointer.
 */
static int o2net_recv_tcp_msg(struct socket *sock, void *data, size_t len)
{
	struct kvec iov = {
		.iov_base = data,
		.iov_len = len,
	};
	struct msghdr hdr = {
		.msg_iov = (struct iovec *)&iov,
		.msg_iovlen = 1,
		.msg_flags = MSG_DONTWAIT,
	};
	mm_segment_t saved_fs = get_fs();
	int nread;

	set_fs(get_ds());
	nread = sock_recvmsg(sock, &hdr, len, hdr.msg_flags);
	set_fs(saved_fs);

	return nread;
}

/*
 * Send the 'veclen' kernel buffers in 'vec', 'total' bytes altogether, on
 * 'sock' with a single sock_sendmsg() call.
 *
 * Returns 0 only if exactly 'total' bytes were accepted.  A NULL socket
 * yields -EINVAL, a short send is reported as -EPIPE and a negative result
 * from sock_sendmsg() is passed through.
 */
static int o2net_send_tcp_msg(struct socket *sock, struct kvec *vec,
			      size_t veclen, size_t total)
{
	struct msghdr hdr = {
		.msg_iov = (struct iovec *)vec,
		.msg_iovlen = veclen,
	};
	mm_segment_t saved_fs;
	int ret;

	if (sock == NULL) {
		ret = -EINVAL;
	} else {
		saved_fs = get_fs();
		set_fs(get_ds());
		ret = sock_sendmsg(sock, &hdr, total);
		set_fs(saved_fs);

		if (ret == total)
			return 0;

		mlog(ML_ERROR, "sendmsg returned %d instead of %zu\n", ret,
		     total);
		if (ret >= 0)
			ret = -EPIPE; /* should be smarter, I bet */
	}

	/* every path that reaches this point carries a negative errno */
	mlog(0, "returning error: %d\n", ret);
	return ret;
}

/*
 * Push 'size' bytes starting at the kernel address 'kmalloced_virt' out
 * through the socket's sendpage operation without blocking.  Anything short
 * of a complete send is logged and the container is detached from its node
 * with a non-persistent shutdown.
 */
static void o2net_sendpage(struct o2net_sock_container *sc,
			   void *kmalloced_virt,
			   size_t size)
{
	struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);
	unsigned long offset = (long)kmalloced_virt & ~PAGE_MASK;
	ssize_t sent;

	sent = sc->sc_sock->ops->sendpage(sc->sc_sock,
					  virt_to_page(kmalloced_virt),
					  offset, size, MSG_DONTWAIT);
	if (sent == size)
		return;

	mlog(ML_ERROR, "sendpage of size %zu to " SC_NODEF_FMT 
	     " failed with %zd\n", size, SC_NODEF_ARGS(sc), sent);
	o2net_ensure_shutdown(nn, sc, 0);
}

/*
 * Zero a message header and fill in the magic, payload length, message type
 * and key in big-endian order.  Both status fields start out as "no error";
 * the message number is left at zero.
 */
static void o2net_init_msg(struct o2net_msg *msg, u16 data_len, u16 msg_type, u32 key)
{
	memset(msg, 0, sizeof(*msg));

	msg->magic = cpu_to_be16(O2NET_MSG_MAGIC);
	msg->msg_type = cpu_to_be16(msg_type);
	msg->data_len = cpu_to_be16(data_len);
	msg->key = cpu_to_be32(key);

	msg->sys_status = cpu_to_be32(O2NET_ERR_NONE);
	msg->status = 0;
}

/*
 * Wait condition for transmitters: returns non-zero once the node either
 * has a persistent error or a valid sock container.
 *
 * On a persistent error *error is set to it and *sc_ret to NULL.  With a
 * valid container *sc_ret receives it with an extra kref held for the
 * caller and *error is 0.  When 0 is returned neither output is touched.
 */
static int o2net_tx_can_proceed(struct o2net_node *nn,
			        struct o2net_sock_container **sc_ret,
				int *error)
{
	int can_proceed = 1;

	spin_lock(&nn->nn_lock);
	if (nn->nn_persistent_error) {
		*sc_ret = NULL;
		*error = nn->nn_persistent_error;
	} else if (nn->nn_sc_valid) {
		kref_get(&nn->nn_sc->sc_kref);
		*sc_ret = nn->nn_sc;
		*error = 0;
	} else {
		can_proceed = 0;
	}
	spin_unlock(&nn->nn_lock);

	return can_proceed;
}

int o2net_send_message_vec(u32 msg_type, u32 key, struct kvec *caller_vec,
			   size_t caller_veclen, u8 target_node, int *status)
{
	int ret, error = 0;
	struct o2net_msg *msg = NULL;
	size_t veclen, caller_bytes = 0;
	struct kvec *vec = NULL;
	struct o2net_sock_container *sc = NULL;
	struct o2net_node *nn = o2net_nn_from_num(target_node);
	struct o2net_status_wait nsw = {
		.ns_node_item = LIST_HEAD_INIT(nsw.ns_node_item),
	};

	if (o2net_wq == NULL) {
		mlog(0, "attempt to tx without o2netd running\n");
		ret = -ESRCH;
		goto out;
	}

	if (caller_veclen == 0) {
		mlog(0, "bad kvec array length\n");
		ret = -EINVAL;
		goto out;
	}

	caller_bytes = iov_length((struct iovec *)caller_vec, caller_veclen);
	if (caller_bytes > O2NET_MAX_PAYLOAD_BYTES) {
		mlog(0, "total payload len %zu too large\n", caller_bytes);
		ret = -EINVAL;
		goto out;
	}

	if (target_node == o2nm_this_node()) {
		ret = -ELOOP;
		goto out;
	}

	ret = wait_event_interruptible(nn->nn_sc_wq,
				       o2net_tx_can_proceed(nn, &sc, &error));
	if (!ret && error)
		ret = error;
	if (ret)
		goto out;

	veclen = caller_veclen + 1;
	vec = kmalloc(sizeof(struct kvec) * veclen, GFP_ATOMIC);
	if (vec == NULL) {
		mlog(0, "failed to %zu element kvec!\n", veclen);
		ret = -ENOMEM;
		goto out;
	}

	msg = kmalloc(sizeof(struct o2net_msg), GFP_ATOMIC);
	if (!msg) {
		mlog(0, "failed to allocate a o2net_msg!\n");
		ret = -ENOMEM;
		goto out;
	}

	o2net_init_msg(msg, caller_bytes, msg_type, key);

	vec[0].iov_len = sizeof(struct o2net_msg);
	vec[0].iov_base = msg;
	memcpy(&vec[1], caller_vec, caller_veclen * sizeof(struct kvec));

	ret = o2net_prep_nsw(nn, &nsw);
	if (ret)
		goto out;

	msg->msg_num = cpu_to_be32(nsw.ns_id);

	/* finally, convert the message header to network byte-order
	 * and send */
	ret = o2net_send_tcp_msg(sc->sc_sock, vec, veclen,
				 sizeof(struct o2net_msg) + caller_bytes);
	msglog(msg, "sending returned %d\n", ret);
	if (ret < 0) {
		mlog(0, "error returned from o2net_send_tcp_msg=%d\n", ret);
		goto out;
	}

	/* wait on other node's handler */
	wait_event(nsw.ns_wq, o2net_nsw_completed(nn, &nsw));

	/* Note that we avoid overwriting the callers status return
	 * variable if a system error was reported on the other
	 * side. Callers beware. */
	ret = o2net_sys_err_to_errno(nsw.ns_sys_status);
	if (status && !ret)
		*status = nsw.ns_status;

	mlog(0, "woken, returning system status %d, user status %d\n",
	     ret, nsw.ns_status);
out:
	if (sc)
		sc_put(sc);
	if (vec)