Skip to content
Snippets Groups Projects
tcp.c 58.2 KiB
Newer Older
}

/*
 * Prepare an outgoing message header: zero the whole struct, then store the
 * magic, type, payload length, key and status fields in big-endian form.
 */
static void o2net_init_msg(struct o2net_msg *msg, u16 data_len, u16 msg_type, u32 key)
{
	memset(msg, 0, sizeof(*msg));

	msg->magic = cpu_to_be16(O2NET_MSG_MAGIC);
	msg->msg_type = cpu_to_be16(msg_type);
	msg->data_len = cpu_to_be16(data_len);
	msg->key = cpu_to_be32(key);

	/* a fresh request carries no error and a zero handler status */
	msg->sys_status = cpu_to_be32(O2NET_ERR_NONE);
	msg->status = 0;
}

/*
 * Decide, under nn_lock, whether a sender may stop waiting on this node.
 *
 * Returns 1 with *error set to the node's persistent error and *sc_ret NULL
 * when such an error is recorded; returns 1 with *error 0 and *sc_ret holding
 * a freshly taken kref on nn->nn_sc when the sc is marked valid.  Otherwise
 * returns 0 and leaves both output parameters untouched.
 */
static int o2net_tx_can_proceed(struct o2net_node *nn,
				struct o2net_sock_container **sc_ret,
				int *error)
{
	int proceed = 1;

	spin_lock(&nn->nn_lock);
	if (nn->nn_persistent_error) {
		*sc_ret = NULL;
		*error = nn->nn_persistent_error;
	} else if (nn->nn_sc_valid) {
		kref_get(&nn->nn_sc->sc_kref);
		*sc_ret = nn->nn_sc;
		*error = 0;
	} else {
		proceed = 0;
	}
	spin_unlock(&nn->nn_lock);

	return proceed;
}

/*
 * Get a map of all nodes to which this node is currently connected.
 *
 * @map:   bitmap to fill; it is cleared first, then bit N is set for every
 *         node N that has a valid connection and no persistent error
 * @bytes: size of @map in bytes; must cover O2NM_MAX_NODES bits or we BUG
 */
void o2net_fill_node_map(unsigned long *map, unsigned bytes)
{
	struct o2net_sock_container *sc;
	int node, ret;

	BUG_ON(bytes < (BITS_TO_LONGS(O2NM_MAX_NODES) * sizeof(unsigned long)));

	memset(map, 0, bytes);
	for (node = 0; node < O2NM_MAX_NODES; ++node) {
		/*
		 * o2net_tx_can_proceed() only writes sc and ret when it
		 * returns non-zero, so skip the node otherwise instead of
		 * reading them uninitialized.
		 */
		if (!o2net_tx_can_proceed(o2net_nn_from_num(node), &sc, &ret))
			continue;
		if (!ret) {
			set_bit(node, map);
			/* drop the ref o2net_tx_can_proceed() took for us */
			sc_put(sc);
		}
	}
}
EXPORT_SYMBOL_GPL(o2net_fill_node_map);

/*
 * Send a message assembled from a caller-supplied kvec array to @target_node
 * and wait for the remote side's status reply.
 *
 * @msg_type, @key: stored in the message header by o2net_init_msg()
 * @caller_vec, @caller_veclen: payload segments; at least one element is
 *	required and the total may not exceed O2NET_MAX_PAYLOAD_BYTES
 * @target_node: destination node number; sending to ourselves is -ELOOP
 * @status: optional; receives the remote handler's status, but only when no
 *	system error was reported
 *
 * Returns a negative errno on local failure (-ESRCH, -EINVAL, -ELOOP,
 * -ENOMEM, the node's persistent error, or a send error), otherwise the
 * result of o2net_sys_err_to_errno() on the reported system status.
 */
int o2net_send_message_vec(u32 msg_type, u32 key, struct kvec *caller_vec,
			   size_t caller_veclen, u8 target_node, int *status)
{
	int ret = 0;
	struct o2net_msg *msg = NULL;
	size_t veclen, caller_bytes = 0;
	struct kvec *vec = NULL;
	struct o2net_sock_container *sc = NULL;
	struct o2net_node *nn = o2net_nn_from_num(target_node);
	struct o2net_status_wait nsw = {
		.ns_node_item = LIST_HEAD_INIT(nsw.ns_node_item),
	};
	struct o2net_send_tracking nst;

	o2net_init_nst(&nst, msg_type, key, current, target_node);

	if (o2net_wq == NULL) {
		mlog(0, "attempt to tx without o2netd running\n");
		ret = -ESRCH;
		goto out;
	}

	if (caller_veclen == 0) {
		mlog(0, "bad kvec array length\n");
		ret = -EINVAL;
		goto out;
	}

	caller_bytes = iov_length((struct iovec *)caller_vec, caller_veclen);
	if (caller_bytes > O2NET_MAX_PAYLOAD_BYTES) {
		mlog(0, "total payload len %zu too large\n", caller_bytes);
		ret = -EINVAL;
		goto out;
	}

	if (target_node == o2nm_this_node()) {
		ret = -ELOOP;
		goto out;
	}

	o2net_debug_add_nst(&nst);

	o2net_set_nst_sock_time(&nst);

	/*
	 * Sleeps until the node either has a valid sc (ret == 0, sc holds a
	 * reference) or a persistent error (ret != 0, sc == NULL).  In the
	 * error case we must bail out before touching sc.
	 */
	wait_event(nn->nn_sc_wq, o2net_tx_can_proceed(nn, &sc, &ret));
	if (ret)
		goto out;

	o2net_set_nst_sock_container(&nst, sc);

	/* one extra leading element for our own message header */
	veclen = caller_veclen + 1;
	vec = kmalloc(sizeof(struct kvec) * veclen, GFP_ATOMIC);
	if (vec == NULL) {
		mlog(0, "failed to %zu element kvec!\n", veclen);
		ret = -ENOMEM;
		goto out;
	}

	msg = kmalloc(sizeof(struct o2net_msg), GFP_ATOMIC);
	if (!msg) {
		mlog(0, "failed to allocate a o2net_msg!\n");
		ret = -ENOMEM;
		goto out;
	}

	o2net_init_msg(msg, caller_bytes, msg_type, key);

	vec[0].iov_len = sizeof(struct o2net_msg);
	vec[0].iov_base = msg;
	memcpy(&vec[1], caller_vec, caller_veclen * sizeof(struct kvec));

	ret = o2net_prep_nsw(nn, &nsw);
	if (ret)
		goto out;

	msg->msg_num = cpu_to_be32(nsw.ns_id);
	o2net_set_nst_msg_id(&nst, nsw.ns_id);

	o2net_set_nst_send_time(&nst);

	/* the header is already in network byte-order; serialize the send
	 * against other writers on this socket */
	mutex_lock(&sc->sc_send_lock);
	ret = o2net_send_tcp_msg(sc->sc_sock, vec, veclen,
				 sizeof(struct o2net_msg) + caller_bytes);
	mutex_unlock(&sc->sc_send_lock);
	msglog(msg, "sending returned %d\n", ret);
	if (ret < 0) {
		mlog(0, "error returned from o2net_send_tcp_msg=%d\n", ret);
		goto out;
	}

	/* wait on other node's handler */
	o2net_set_nst_status_time(&nst);
	wait_event(nsw.ns_wq, o2net_nsw_completed(nn, &nsw));

	o2net_update_send_stats(&nst, sc);

	/* Note that we avoid overwriting the callers status return
	 * variable if a system error was reported on the other
	 * side. Callers beware. */
	ret = o2net_sys_err_to_errno(nsw.ns_sys_status);
	if (status && !ret)
		*status = nsw.ns_status;

	mlog(0, "woken, returning system status %d, user status %d\n",
	     ret, nsw.ns_status);
out:
	o2net_debug_del_nst(&nst); /* must be before dropping sc and node */
	if (sc)
		sc_put(sc);
	/* kfree(NULL) is a no-op, no need to guard */
	kfree(vec);
	kfree(msg);
	o2net_complete_nsw(nn, &nsw, 0, 0, 0);
	return ret;
}
EXPORT_SYMBOL_GPL(o2net_send_message_vec);

/*
 * Convenience wrapper around o2net_send_message_vec() for callers that have
 * one contiguous buffer: @data/@len are wrapped in a single-element kvec and
 * the remaining arguments are passed straight through.
 */
int o2net_send_message(u32 msg_type, u32 key, void *data, u32 len,
		       u8 target_node, int *status)
{
	struct kvec single;

	single.iov_base = data;
	single.iov_len = len;

	return o2net_send_message_vec(msg_type, key, &single, 1,
				      target_node, status);
}
EXPORT_SYMBOL_GPL(o2net_send_message);

static int o2net_send_status_magic(struct socket *sock, struct o2net_msg *hdr,
				   enum o2net_system_error syserr, int err)
{
	struct kvec vec = {
		.iov_base = hdr,
		.iov_len = sizeof(struct o2net_msg),
	};

	BUG_ON(syserr >= O2NET_ERR_MAX);

	/* leave other fields intact from the incoming message, msg_num
	 * in particular */
	hdr->sys_status = cpu_to_be32(syserr);
	hdr->status = cpu_to_be32(err);
	hdr->magic = cpu_to_be16(O2NET_MSG_STATUS_MAGIC);  // twiddle the magic
	hdr->data_len = 0;

	msglog(hdr, "about to send status magic %d\n", err);
	/* hdr has been in host byteorder this whole time */
	return o2net_send_tcp_msg(sock, &vec, 1, sizeof(struct o2net_msg));
}

/* this returns -errno if the header was unknown or too large, etc.
 * after this is called the buffer is reused for the next message.
 *
 * Status replies complete the matching waiter, keepalive requests are
 * answered, keepalive responses are ignored.  Regular messages are handed
 * to the handler registered for (msg_type, key) and a status reply carrying
 * the handler's return value, or a system error if no handler matched or the
 * payload was too long for it, is sent back on the same socket. */
static int o2net_process_message(struct o2net_sock_container *sc,
				 struct o2net_msg *hdr)
{
	struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);
	int ret = 0, handler_status;
	enum  o2net_system_error syserr;
	struct o2net_msg_handler *nmh = NULL;
	void *ret_data = NULL;

	msglog(hdr, "processing message\n");

	o2net_sc_postpone_idle(sc);

	switch(be16_to_cpu(hdr->magic)) {
		case O2NET_MSG_STATUS_MAGIC:
			/* special type for returning message status */
			o2net_complete_nsw(nn, NULL,
					   be32_to_cpu(hdr->msg_num),
					   be32_to_cpu(hdr->sys_status),
					   be32_to_cpu(hdr->status));
			goto out;
		case O2NET_MSG_KEEP_REQ_MAGIC:
			o2net_sendpage(sc, o2net_keep_resp,
				       sizeof(*o2net_keep_resp));
			goto out;
		case O2NET_MSG_KEEP_RESP_MAGIC:
			goto out;
		case O2NET_MSG_MAGIC:
			break;
		default:
			msglog(hdr, "bad magic\n");
			ret = -EINVAL;
			goto out;
	}

	/* find a handler for it */
	handler_status = 0;
	nmh = o2net_handler_get(be16_to_cpu(hdr->msg_type),
				be32_to_cpu(hdr->key));
	if (!nmh) {
		mlog(ML_TCP, "couldn't find handler for type %u key %08x\n",
		     be16_to_cpu(hdr->msg_type), be32_to_cpu(hdr->key));
		syserr = O2NET_ERR_NO_HNDLR;
		goto out_respond;
	}

	syserr = O2NET_ERR_NONE;

	if (be16_to_cpu(hdr->data_len) > nmh->nh_max_len)
		syserr = O2NET_ERR_OVERFLOW;

	if (syserr != O2NET_ERR_NONE)
		goto out_respond;

	sc->sc_msg_key = be32_to_cpu(hdr->key);
	sc->sc_msg_type = be16_to_cpu(hdr->msg_type);
	/*
	 * The handler gets the header, the total message length, its
	 * registered private data and a slot for data to hand to its
	 * post function.
	 * NOTE(review): this copy of the file had lost lines around this
	 * call; confirm against upstream whether timing/statistics hooks
	 * belong before and after it.
	 */
	handler_status = (nmh->nh_func)(hdr, sizeof(struct o2net_msg) +
					     be16_to_cpu(hdr->data_len),
					nmh->nh_func_data, &ret_data);

out_respond:
	/* this destroys the hdr, so don't use it after this */
	mutex_lock(&sc->sc_send_lock);
	ret = o2net_send_status_magic(sc->sc_sock, hdr, syserr,
				      handler_status);
	mutex_unlock(&sc->sc_send_lock);
	hdr = NULL;
	mlog(0, "sending handler status %d, syserr %d returned %d\n",
	     handler_status, syserr, ret);

	if (nmh) {
		/* a handler that returns data must have a post function
		 * to consume it */
		BUG_ON(ret_data != NULL && nmh->nh_post_func == NULL);
		if (nmh->nh_post_func)
			(nmh->nh_post_func)(handler_status, nmh->nh_func_data,
					    ret_data);
	}

out:
	if (nmh)
		o2net_handler_put(nmh);
	return ret;
}

static int o2net_check_handshake(struct o2net_sock_container *sc)
{
	struct o2net_handshake *hand = page_address(sc->sc_page);
	struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);

	if (hand->protocol_version != cpu_to_be64(O2NET_PROTOCOL_VERSION)) {
		printk(KERN_NOTICE "o2net: " SC_NODEF_FMT " Advertised net "
		       "protocol version %llu but %llu is required. "
		       "Disconnecting.\n", SC_NODEF_ARGS(sc),
		       (unsigned long long)be64_to_cpu(hand->protocol_version),
		       O2NET_PROTOCOL_VERSION);

		/* don't bother reconnecting if its the wrong version. */
		o2net_ensure_shutdown(nn, sc, -ENOTCONN);
		return -1;
	}

	/*
	 * Ensure timeouts are consistent with other nodes, otherwise
	 * we can end up with one node thinking that the other must be down,
	 * but isn't. This can ultimately cause corruption.
	 */
	if (be32_to_cpu(hand->o2net_idle_timeout_ms) !=
		printk(KERN_NOTICE "o2net: " SC_NODEF_FMT " uses a network "
		       "idle timeout of %u ms, but we use %u ms locally. "
		       "Disconnecting.\n", SC_NODEF_ARGS(sc),
		       be32_to_cpu(hand->o2net_idle_timeout_ms),
		       o2net_idle_timeout());
		o2net_ensure_shutdown(nn, sc, -ENOTCONN);
		return -1;
	}

	if (be32_to_cpu(hand->o2net_keepalive_delay_ms) !=
			o2net_keepalive_delay()) {
		printk(KERN_NOTICE "o2net: " SC_NODEF_FMT " uses a keepalive "
		       "delay of %u ms, but we use %u ms locally. "
		       "Disconnecting.\n", SC_NODEF_ARGS(sc),
		       be32_to_cpu(hand->o2net_keepalive_delay_ms),
		       o2net_keepalive_delay());
		o2net_ensure_shutdown(nn, sc, -ENOTCONN);
		return -1;
	}

	if (be32_to_cpu(hand->o2hb_heartbeat_timeout_ms) !=
			O2HB_MAX_WRITE_TIMEOUT_MS) {
		printk(KERN_NOTICE "o2net: " SC_NODEF_FMT " uses a heartbeat "
		       "timeout of %u ms, but we use %u ms locally. "
		       "Disconnecting.\n", SC_NODEF_ARGS(sc),
		       be32_to_cpu(hand->o2hb_heartbeat_timeout_ms),
		       O2HB_MAX_WRITE_TIMEOUT_MS);
		o2net_ensure_shutdown(nn, sc, -ENOTCONN);
		return -1;
	}

	sc->sc_handshake_ok = 1;

	spin_lock(&nn->nn_lock);
	/* set valid and queue the idle timers only if it hasn't been
	 * shut down already */
	if (nn->nn_sc == sc) {
		o2net_sc_reset_idle_timer(sc);
		atomic_set(&nn->nn_timeout, 0);
		o2net_set_nn_state(nn, sc, 1, 0);
	}
	spin_unlock(&nn->nn_lock);

	/* shift everything up as though it wasn't there */
	sc->sc_page_off -= sizeof(struct o2net_handshake);
	if (sc->sc_page_off)
		memmove(hand, hand + 1, sc->sc_page_off);

	return 0;
}

/*
 * Pull queued rx bytes into sc->sc_page and act on them: first the handshake,
 * then message header, then payload.  Once a whole message is in the page it
 * is handed to o2net_process_message() and the page offset is reset.
 *
 * Returns a negative error, 0 on eof, or > 0 when progress was made.
 */
static int o2net_advance_rx(struct o2net_sock_container *sc)
{
	struct o2net_msg *hdr;
	size_t msg_len;
	size_t want;
	void *dst;
	int ret = 0;

	sclog(sc, "receiving\n");

	if (unlikely(sc->sc_handshake_ok == 0)) {
		/* still collecting (or about to validate) the handshake */
		if (sc->sc_page_off < sizeof(struct o2net_handshake)) {
			dst = page_address(sc->sc_page) + sc->sc_page_off;
			want = sizeof(struct o2net_handshake) - sc->sc_page_off;
			ret = o2net_recv_tcp_msg(sc->sc_sock, dst, want);
			if (ret > 0)
				sc->sc_page_off += ret;
		}

		if (sc->sc_page_off == sizeof(struct o2net_handshake)) {
			o2net_check_handshake(sc);
			if (unlikely(sc->sc_handshake_ok == 0))
				ret = -EPROTO;
		}
		goto out;
	}

	/* do we need more header? */
	if (sc->sc_page_off < sizeof(struct o2net_msg)) {
		dst = page_address(sc->sc_page) + sc->sc_page_off;
		want = sizeof(struct o2net_msg) - sc->sc_page_off;
		ret = o2net_recv_tcp_msg(sc->sc_sock, dst, want);
		if (ret <= 0)
			goto out;

		sc->sc_page_off += ret;
		/* the header completes exactly once per message, so this
		 * is the one place the advertised length gets validated */
		if (sc->sc_page_off == sizeof(struct o2net_msg)) {
			hdr = page_address(sc->sc_page);
			if (be16_to_cpu(hdr->data_len) >
			    O2NET_MAX_PAYLOAD_BYTES) {
				ret = -EOVERFLOW;
				goto out;
			}
		}
	}

	/* oof, still don't have a header */
	if (sc->sc_page_off < sizeof(struct o2net_msg))
		goto out;

	hdr = page_address(sc->sc_page);

	msglog(hdr, "at page_off %zu\n", sc->sc_page_off);

	/* header plus the payload it announces */
	msg_len = sizeof(struct o2net_msg) + be16_to_cpu(hdr->data_len);

	/* do we need more payload? */
	if (sc->sc_page_off < msg_len) {
		dst = page_address(sc->sc_page) + sc->sc_page_off;
		want = msg_len - sc->sc_page_off;
		ret = o2net_recv_tcp_msg(sc->sc_sock, dst, want);
		if (ret <= 0)
			goto out;
		sc->sc_page_off += ret;
	}

	if (sc->sc_page_off == msg_len) {
		/* whole message is in: report progress if the handler
		 * works out.  the message is toast after this call */
		ret = o2net_process_message(sc, hdr);
		if (ret == 0)
			ret = 1;
		sc->sc_page_off = 0;
	}

out:
	sclog(sc, "ret = %d\n", ret);
	return ret;
}

/* this work func is triggerd by data ready.  it reads until it can read no
 * more.  it interprets 0, eof, as fatal.  if data_ready hits while we're doing
 * our work the work struct will be marked and we'll be called again. */
static void o2net_rx_until_empty(struct work_struct *work)
	struct o2net_sock_container *sc =
		container_of(work, struct o2net_sock_container, sc_rx_work);
	int ret;

	do {
		ret = o2net_advance_rx(sc);
	} while (ret > 0);

	if (ret <= 0 && ret != -EAGAIN) {
		struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);
		sclog(sc, "saw error %d, closing\n", ret);
		/* not permanent so read failed handshake can retry */
		o2net_ensure_shutdown(nn, sc, 0);
	}

	sc_put(sc);
}

/*
 * Turn on TCP_NODELAY for @sock and return the setsockopt result.
 *
 * The option value lives in kernel memory, so the address limit is switched
 * to KERNEL_DS around the call and restored afterwards.
 *
 * A warning to whoever touches this next: go through sock->ops->setsockopt()
 * here, not sock_setsockopt().  The latter ignores its level argument and
 * assumes SOL_SOCKET, so TCP_NODELAY would silently turn into SO_DEBUG.
 */
static int o2net_set_nodelay(struct socket *sock)
{
	mm_segment_t saved_fs = get_fs();
	int enable = 1;
	int rc;

	set_fs(KERNEL_DS);
	rc = sock->ops->setsockopt(sock, SOL_TCP, TCP_NODELAY,
				   (char __user *)&enable, sizeof(enable));
	set_fs(saved_fs);

	return rc;
}

/*
 * Refresh the timeout fields of the shared handshake block with the current
 * local settings, stored big-endian, so the peer can compare them against
 * its own (see o2net_check_handshake()).
 */
static void o2net_initialize_handshake(void)
{
	o2net_hand->o2hb_heartbeat_timeout_ms = cpu_to_be32(
		O2HB_MAX_WRITE_TIMEOUT_MS);
	o2net_hand->o2net_idle_timeout_ms = cpu_to_be32(o2net_idle_timeout());
	o2net_hand->o2net_keepalive_delay_ms = cpu_to_be32(
		o2net_keepalive_delay());
	o2net_hand->o2net_reconnect_delay_ms = cpu_to_be32(
		o2net_reconnect_delay());
}

/* ------------------------------------------------------------ */

/* called when a connect completes and after a sock is accepted.  the
 * rx path will see the response and mark the sc valid */
static void o2net_sc_connect_completed(struct work_struct *work)
	struct o2net_sock_container *sc =
		container_of(work, struct o2net_sock_container,
			     sc_connect_work);

	mlog(ML_MSG, "sc sending handshake with ver %llu id %llx\n",
              (unsigned long long)O2NET_PROTOCOL_VERSION,
	      (unsigned long long)be64_to_cpu(o2net_hand->connector_id));

	o2net_initialize_handshake();
	o2net_sendpage(sc, o2net_hand, sizeof(*o2net_hand));
	sc_put(sc);
}

/* this is called as a work_struct func. */
static void o2net_sc_send_keep_req(struct work_struct *work)
	struct o2net_sock_container *sc =
		container_of(work, struct o2net_sock_container,
			     sc_keepalive_work.work);

	o2net_sendpage(sc, o2net_keep_req, sizeof(*o2net_keep_req));
	sc_put(sc);
}

/* socket shutdown does a del_timer_sync against this as it tears down.
 * we can't start this timer until we've got to the point in sc buildup
 * where shutdown is going to be involved.
 *
 * @data is the sc whose idle timer fired; we log how long it has been idle
 * and queue its shutdown work. */
static void o2net_idle_timer(unsigned long data)
{
	struct o2net_sock_container *sc = (struct o2net_sock_container *)data;
	struct o2net_node *nn = o2net_nn_from_num(sc->sc_node->nd_num);
	/*
	 * NOTE(review): the guard macro was lost in this copy; CONFIG_DEBUG_FS
	 * is presumably what gates sc_tv_timer - confirm against the struct
	 * definition.  Without it we can only report the configured timeout.
	 */
#ifdef CONFIG_DEBUG_FS
	unsigned long msecs = ktime_to_ms(ktime_get()) -
		ktime_to_ms(sc->sc_tv_timer);
#else
	unsigned long msecs = o2net_idle_timeout();
#endif

	printk(KERN_NOTICE "o2net: Connection to " SC_NODEF_FMT " has been "
	       "idle for %lu.%lu secs, shutting it down.\n", SC_NODEF_ARGS(sc),
	       msecs / 1000, msecs % 1000);

	/*
	 * Initialize the nn_timeout so that the next connection attempt
	 * will continue in o2net_start_connect.
	 */
	atomic_set(&nn->nn_timeout, 1);

	o2net_sc_queue_work(sc, &sc->sc_shutdown_work);
}

/*
 * Restart both liveness mechanisms for @sc: requeue the keepalive work one
 * keepalive delay from now and push the idle timer out by a full idle
 * timeout.
 */
static void o2net_sc_reset_idle_timer(struct o2net_sock_container *sc)
{
	unsigned long keepalive_delay;
	unsigned long idle_expiry;

	o2net_sc_cancel_delayed_work(sc, &sc->sc_keepalive_work);

	keepalive_delay = msecs_to_jiffies(o2net_keepalive_delay());
	o2net_sc_queue_delayed_work(sc, &sc->sc_keepalive_work,
				    keepalive_delay);

	idle_expiry = jiffies + msecs_to_jiffies(o2net_idle_timeout());
	mod_timer(&sc->sc_idle_timeout, idle_expiry);
}

/*
 * Push out the idle timer of @sc, but only when one is already pending;
 * a timer that was never armed (or already fired) is left alone.
 */
static void o2net_sc_postpone_idle(struct o2net_sock_container *sc)
{
	if (!timer_pending(&sc->sc_idle_timeout))
		return;

	o2net_sc_reset_idle_timer(sc);
}

/* this work func is kicked whenever a path sets the nn state which doesn't
 * have valid set.  This includes seeing hb come up, losing a connection,
 * having a connect attempt fail, etc. This centralizes the logic which decides
 * if a connect attempt should be made or if we should give up and all future
 * transmit attempts should fail */
static void o2net_start_connect(struct work_struct *work)
	struct o2net_node *nn =
		container_of(work, struct o2net_node, nn_connect_work.work);
	struct o2net_sock_container *sc = NULL;
	struct o2nm_node *node = NULL, *mynode = NULL;
	struct socket *sock = NULL;
	struct sockaddr_in myaddr = {0, }, remoteaddr = {0, };
	int ret = 0, stop;
	unsigned int timeout;

	/* if we're greater we initiate tx, otherwise we accept */
	if (o2nm_this_node() <= o2net_num_from_nn(nn))
		goto out;

	/* watch for racing with tearing a node down */
	node = o2nm_get_node_by_num(o2net_num_from_nn(nn));
	if (node == NULL) {
		ret = 0;
		goto out;
	}

	mynode = o2nm_get_node_by_num(o2nm_this_node());
	if (mynode == NULL) {
		ret = 0;
		goto out;
	}

	spin_lock(&nn->nn_lock);
	/*
	 * see if we already have one pending or have given up.
	 * For nn_timeout, it is set when we close the connection
	 * because of the idle time out. So it means that we have
	 * at least connected to that node successfully once,
	 * now try to connect to it again.
	 */
	timeout = atomic_read(&nn->nn_timeout);
	stop = (nn->nn_sc ||
		(nn->nn_persistent_error &&
		(nn->nn_persistent_error != -ENOTCONN || timeout == 0)));
	spin_unlock(&nn->nn_lock);
	if (stop)
		goto out;

	nn->nn_last_connect_attempt = jiffies;

	sc = sc_alloc(node);
	if (sc == NULL) {
		mlog(0, "couldn't allocate sc\n");
		ret = -ENOMEM;
		goto out;
	}

	ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
	if (ret < 0) {
		mlog(0, "can't create socket: %d\n", ret);
		goto out;
	}
	sc->sc_sock = sock; /* freed by sc_kref_release */

	sock->sk->sk_allocation = GFP_ATOMIC;

	myaddr.sin_family = AF_INET;
	myaddr.sin_addr.s_addr = mynode->nd_ipv4_address;
	myaddr.sin_port = htons(0); /* any port */

	ret = sock->ops->bind(sock, (struct sockaddr *)&myaddr,
			      sizeof(myaddr));
	if (ret) {
Harvey Harrison's avatar
Harvey Harrison committed
		mlog(ML_ERROR, "bind failed with %d at address %pI4\n",
		     ret, &mynode->nd_ipv4_address);
		goto out;
	}

	ret = o2net_set_nodelay(sc->sc_sock);
	if (ret) {
		mlog(ML_ERROR, "setting TCP_NODELAY failed with %d\n", ret);
		goto out;
	}

	o2net_register_callbacks(sc->sc_sock->sk, sc);

	spin_lock(&nn->nn_lock);
	/* handshake completion will set nn->nn_sc_valid */
	o2net_set_nn_state(nn, sc, 0, 0);
	spin_unlock(&nn->nn_lock);

	remoteaddr.sin_family = AF_INET;
	remoteaddr.sin_addr.s_addr = node->nd_ipv4_address;
	remoteaddr.sin_port = node->nd_ipv4_port;

	ret = sc->sc_sock->ops->connect(sc->sc_sock,
					(struct sockaddr *)&remoteaddr,
					sizeof(remoteaddr),
					O_NONBLOCK);
	if (ret == -EINPROGRESS)
		ret = 0;

out:
	if (ret) {
		printk(KERN_NOTICE "o2net: Connect attempt to " SC_NODEF_FMT
		       " failed with errno %d\n", SC_NODEF_ARGS(sc), ret);
		/* 0 err so that another will be queued and attempted
		 * from set_nn_state */
		if (sc)
			o2net_ensure_shutdown(nn, sc, 0);
	}
	if (sc)
		sc_put(sc);
	if (node)
		o2nm_node_put(node);
static void o2net_connect_expired(struct work_struct *work)
	struct o2net_node *nn =
		container_of(work, struct o2net_node, nn_connect_expired.work);

	spin_lock(&nn->nn_lock);
	if (!nn->nn_sc_valid) {
		printk(KERN_NOTICE "o2net: No connection established with "
		       "node %u after %u.%u seconds, giving up.\n",
		     o2net_num_from_nn(nn),
		     o2net_idle_timeout() / 1000,
		     o2net_idle_timeout() % 1000);

		o2net_set_nn_state(nn, NULL, 0, -ENOTCONN);
	}
	spin_unlock(&nn->nn_lock);
}

/* Delayed work that reports to the quorum code that this node's heartbeat
 * is still up. */
static void o2net_still_up(struct work_struct *work)
{
	struct o2net_node *nn =
		container_of(work, struct o2net_node, nn_still_up.work);

	o2quo_hb_still_up(o2net_num_from_nn(nn));
}

/* ------------------------------------------------------------ */

/*
 * Tear down our connection state for @node: clear the idle-timeout flag and
 * record -ENOTCONN as the persistent error under nn_lock, then, if the
 * o2net workqueue exists, cancel the node's pending delayed work and flush
 * the queue.
 */
void o2net_disconnect_node(struct o2nm_node *node)
{
	struct o2net_node *nn = o2net_nn_from_num(node->nd_num);

	/* don't reconnect until it's heartbeating again */
	spin_lock(&nn->nn_lock);
	atomic_set(&nn->nn_timeout, 0);
	o2net_set_nn_state(nn, NULL, 0, -ENOTCONN);
	spin_unlock(&nn->nn_lock);

	/* without a workqueue there is nothing queued to cancel */
	if (!o2net_wq)
		return;

	cancel_delayed_work(&nn->nn_connect_expired);
	cancel_delayed_work(&nn->nn_connect_work);
	cancel_delayed_work(&nn->nn_still_up);
	flush_workqueue(o2net_wq);
}

/*
 * Heartbeat callback for a node going down.  The quorum code is always told;
 * if the node object is still available and it is not ourselves, the
 * connection to it is torn down as well.
 */
static void o2net_hb_node_down_cb(struct o2nm_node *node, int node_num,
				  void *data)
{
	o2quo_hb_down(node_num);

	/* no node object to disconnect from */
	if (!node)
		return;

	if (node_num != o2nm_this_node())
		o2net_disconnect_node(node);

	BUG_ON(atomic_read(&o2net_connected_peers) < 0);
}

/*
 * Heartbeat callback for a node coming up.  Tells the quorum code, backdates
 * the last connect attempt so the next one is not delayed, and for remote
 * nodes clears any persistent error so connecting can resume.
 */
static void o2net_hb_node_up_cb(struct o2nm_node *node, int node_num,
				void *data)
{
	struct o2net_node *nn = o2net_nn_from_num(node_num);

	o2quo_hb_up(node_num);

	BUG_ON(!node);

	/* ensure an immediate connect attempt */
	nn->nn_last_connect_attempt = jiffies -
		(msecs_to_jiffies(o2net_reconnect_delay()) + 1);

	if (node_num != o2nm_this_node()) {
		/* believe it or not, accept and node heartbeating testing
		 * can succeed for this node before we got here.. so
		 * only use set_nn_state to clear the persistent error
		 * if that hasn't already happened */
		spin_lock(&nn->nn_lock);
		atomic_set(&nn->nn_timeout, 0);
		if (nn->nn_persistent_error)
			o2net_set_nn_state(nn, NULL, 0, 0);
		spin_unlock(&nn->nn_lock);
	}
}

/* Unregister o2net's heartbeat node-up and node-down callbacks. */
void o2net_unregister_hb_callbacks(void)
{
	o2hb_unregister_callback(NULL, &o2net_hb_up);
	o2hb_unregister_callback(NULL, &o2net_hb_down);
}

/*
 * Set up and register o2net's heartbeat node-down and node-up callbacks.
 *
 * Returns 0 when both registrations succeed.  If either one fails, both are
 * unregistered again and the error of the failing registration is returned.
 */
int o2net_register_hb_callbacks(void)
{
	int ret;

	o2hb_setup_callback(&o2net_hb_down, O2HB_NODE_DOWN_CB,
			    o2net_hb_node_down_cb, NULL, O2NET_HB_PRI);
	o2hb_setup_callback(&o2net_hb_up, O2HB_NODE_UP_CB,
			    o2net_hb_node_up_cb, NULL, O2NET_HB_PRI);

	ret = o2hb_register_callback(NULL, &o2net_hb_up);
	/* only try the second one if the first went in, so that a failure
	 * of the first is not overwritten */
	if (ret == 0)
		ret = o2hb_register_callback(NULL, &o2net_hb_down);

	if (ret)
		o2net_unregister_hb_callbacks();

	return ret;
}

/* ------------------------------------------------------------ */

static int o2net_accept_one(struct socket *sock)
{
	int ret, slen;
	struct sockaddr_in sin;
	struct socket *new_sock = NULL;
	struct o2nm_node *node = NULL;
	struct o2nm_node *local_node = NULL;
	struct o2net_sock_container *sc = NULL;
	struct o2net_node *nn;

	BUG_ON(sock == NULL);
	ret = sock_create_lite(sock->sk->sk_family, sock->sk->sk_type,
			       sock->sk->sk_protocol, &new_sock);
	if (ret)
		goto out;

	new_sock->type = sock->type;
	new_sock->ops = sock->ops;
	ret = sock->ops->accept(sock, new_sock, O_NONBLOCK);
	if (ret < 0)
		goto out;

	new_sock->sk->sk_allocation = GFP_ATOMIC;

	ret = o2net_set_nodelay(new_sock);
	if (ret) {
		mlog(ML_ERROR, "setting TCP_NODELAY failed with %d\n", ret);
		goto out;
	}

	slen = sizeof(sin);
	ret = new_sock->ops->getname(new_sock, (struct sockaddr *) &sin,
				       &slen, 1);
	if (ret < 0)
		goto out;

	node = o2nm_get_node_by_ip(sin.sin_addr.s_addr);
	if (node == NULL) {
		printk(KERN_NOTICE "o2net: Attempt to connect from unknown "
		       "node at %pI4:%d\n", &sin.sin_addr.s_addr,
		       ntohs(sin.sin_port));
	if (o2nm_this_node() >= node->nd_num) {
		local_node = o2nm_get_node_by_num(o2nm_this_node());
		printk(KERN_NOTICE "o2net: Unexpected connect attempt seen "
		       "at node '%s' (%u, %pI4:%d) from node '%s' (%u, "
		       "%pI4:%d)\n", local_node->nd_name, local_node->nd_num,
		       &(local_node->nd_ipv4_address),
		       ntohs(local_node->nd_ipv4_port), node->nd_name,
		       node->nd_num, &sin.sin_addr.s_addr, ntohs(sin.sin_port));
		ret = -EINVAL;
		goto out;
	}

	/* this happens all the time when the other node sees our heartbeat
	 * and tries to connect before we see their heartbeat */
	if (!o2hb_check_node_heartbeating_from_callback(node->nd_num)) {
		mlog(ML_CONN, "attempt to connect from node '%s' at "
Harvey Harrison's avatar
Harvey Harrison committed
		     "%pI4:%d but it isn't heartbeating\n",
		     node->nd_name, &sin.sin_addr.s_addr,
		     ntohs(sin.sin_port));
		ret = -EINVAL;
		goto out;
	}

	nn = o2net_nn_from_num(node->nd_num);

	spin_lock(&nn->nn_lock);
	if (nn->nn_sc)
		ret = -EBUSY;
	else
		ret = 0;
	spin_unlock(&nn->nn_lock);
	if (ret) {
		printk(KERN_NOTICE "o2net: Attempt to connect from node '%s' "
		       "at %pI4:%d but it already has an open connection\n",
		       node->nd_name, &sin.sin_addr.s_addr,
		       ntohs(sin.sin_port));
		goto out;
	}

	sc = sc_alloc(node);
	if (sc == NULL) {
		ret = -ENOMEM;
		goto out;
	}

	sc->sc_sock = new_sock;
	new_sock = NULL;

	spin_lock(&nn->nn_lock);
	atomic_set(&nn->nn_timeout, 0);
	o2net_set_nn_state(nn, sc, 0, 0);
	spin_unlock(&nn->nn_lock);

	o2net_register_callbacks(sc->sc_sock->sk, sc);
	o2net_sc_queue_work(sc, &sc->sc_rx_work);

	o2net_initialize_handshake();
	o2net_sendpage(sc, o2net_hand, sizeof(*o2net_hand));

out:
	if (new_sock)
		sock_release(new_sock);
	if (node)
		o2nm_node_put(node);
	if (local_node)
		o2nm_node_put(local_node);
/* Work func for the listening socket: keep accepting connections until
 * o2net_accept_one() reports a non-zero result, yielding between accepts. */
static void o2net_accept_many(struct work_struct *work)
{
	struct socket *sock = o2net_listen_sock;
	while (o2net_accept_one(sock) == 0)
		cond_resched();
}

/*
 * data_ready hook for the listening socket.  Queues the accept work when the
 * socket is in TCP_LISTEN, then chains to the callback saved in
 * sk_user_data.  If sk_user_data is already NULL (teardown race) nothing is
 * queued and whatever sk_data_ready currently points at is called instead.
 */
static void o2net_listen_data_ready(struct sock *sk, int bytes)
{
	void (*chain)(struct sock *sk, int bytes);

	read_lock(&sk->sk_callback_lock);
	chain = sk->sk_user_data;
	if (chain == NULL) {
		/* check for teardown race */
		chain = sk->sk_data_ready;
	} else if (sk->sk_state == TCP_LISTEN) {
		/* ->sk_data_ready also fires for a newly established child
		 * socket before it has been accepted and given its own
		 * data_ready.. listen work is only wanted for the
		 * listening socket itself */
		mlog(ML_TCP, "bytes: %d\n", bytes);
		queue_work(o2net_wq, &o2net_listen_work);
	}
	read_unlock(&sk->sk_callback_lock);

	chain(sk, bytes);
}

static int o2net_open_listening_sock(__be32 addr, __be16 port)
{
	struct socket *sock = NULL;
	int ret;
	struct sockaddr_in sin = {
		.sin_family = PF_INET,
		.sin_addr = { .s_addr = addr },
		.sin_port = port,
	};

	ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
	if (ret < 0) {
		printk(KERN_ERR "o2net: Error %d while creating socket\n", ret);
		goto out;
	}

	sock->sk->sk_allocation = GFP_ATOMIC;

	write_lock_bh(&sock->sk->sk_callback_lock);
	sock->sk->sk_user_data = sock->sk->sk_data_ready;