We have different versions of this function for IPv4 and IPv6, but the caller already requires some IP version specific code to get the right header pointers. Instead, have a common function that fills either an IPv4 or an IPv6 header based on which header pointer it is passed. This allows us to remove a small amount of code duplication, at the cost of introducing a few slightly ugly conditionals. Signed-off-by: David Gibson <david(a)gibson.dropbear.id.au> --- tcp.c | 114 +++++++++++++++++++++---------------------------- tcp_buf.c | 24 +++++------ tcp_internal.h | 16 +++---- tcp_vu.c | 36 ++++++---------- 4 files changed, 79 insertions(+), 111 deletions(-) diff --git a/tcp.c b/tcp.c index 297eb8c..422ecbb 100644 --- a/tcp.c +++ b/tcp.c @@ -882,102 +882,86 @@ static void tcp_fill_header(struct tcphdr *th, } /** - * tcp_fill_headers4() - Fill 802.3, IPv4, TCP headers in pre-cooked buffers + * tcp_fill_headers() - Fill 802.3, IP, TCP headers * @conn: Connection pointer * @taph: tap backend specific header - * @iph: Pointer to IPv4 header + * @ip4h: Pointer to IPv4 header, or NULL + * @ip6h: Pointer to IPv6 header, or NULL * @th: Pointer to TCP header * @iov: IO vector containing payload * @iov_cnt: Number of entries in @iov * @doffset: Offset of the TCP payload within @iov - * @check: Checksum, if already known + * @ip4_check: IPv4 checksum, if already known * @seq: Sequence number for this segment * @no_tcp_csum: Do not set TCP checksum */ -void tcp_fill_headers4(const struct tcp_tap_conn *conn, - struct tap_hdr *taph, struct iphdr *iph, - struct tcphdr *th, - const struct iovec *iov, size_t iov_cnt, size_t doffset, - const uint16_t *check, uint32_t seq, bool no_tcp_csum) +void tcp_fill_headers(const struct tcp_tap_conn *conn, + struct tap_hdr *taph, + struct iphdr *ip4h, struct ipv6hdr *ip6h, + struct tcphdr *th, + const struct iovec *iov, size_t iov_cnt, size_t doffset, + const uint16_t *ip4_check, uint32_t seq, bool no_tcp_csum) { const struct flowside *tapside = TAPFLOW(conn); - 
const struct in_addr *src4 = inany_v4(&tapside->oaddr); - const struct in_addr *dst4 = inany_v4(&tapside->eaddr); size_t dlen = iov_size(iov, iov_cnt) - doffset; size_t l4len = dlen + sizeof(*th); - size_t l3len = l4len + sizeof(*iph); + size_t l3len = l4len; + uint32_t psum = 0; - ASSERT(src4 && dst4); + if (ip4h) { + const struct in_addr *src4 = inany_v4(&tapside->oaddr); + const struct in_addr *dst4 = inany_v4(&tapside->eaddr); - iph->tot_len = htons(l3len); - iph->saddr = src4->s_addr; - iph->daddr = dst4->s_addr; + ASSERT(src4 && dst4); - iph->check = check ? *check : - csum_ip4_header(l3len, IPPROTO_TCP, *src4, *dst4); + l3len += + sizeof(*ip4h); - tcp_fill_header(th, conn, seq); + ip4h->tot_len = htons(l3len); + ip4h->saddr = src4->s_addr; + ip4h->daddr = dst4->s_addr; - if (no_tcp_csum) { - th->check = 0; - } else { - uint32_t psum = proto_ipv4_header_psum(l4len, IPPROTO_TCP, - *src4, *dst4); + if (ip4_check) + ip4h->check = *ip4_check; + else + ip4h->check = csum_ip4_header(l3len, IPPROTO_TCP, + *src4, *dst4); - tcp_update_csum(psum, th, iov, iov_cnt, doffset); + if (!no_tcp_csum) { + psum = proto_ipv4_header_psum(l4len, IPPROTO_TCP, + *src4, *dst4); + } } - tap_hdr_update(taph, l3len + sizeof(struct ethhdr)); -} + if (ip6h) { + l3len += sizeof(*ip6h); -/** - * tcp_fill_headers6() - Fill 802.3, IPv6, TCP headers in pre-cooked buffers - * @conn: Connection pointer - * @taph: tap backend specific header - * @ip6h: Pointer to IPv6 header - * @th: Pointer to TCP header - * @iov: IO vector containing payload - * @iov_cnt: Number of entries in @iov - * @doffset: Offset of the TCP payload within @iov - * @check: Checksum, if already known - * @seq: Sequence number for this segment - * @no_tcp_csum: Do not set TCP checksum - */ -void tcp_fill_headers6(const struct tcp_tap_conn *conn, - struct tap_hdr *taph, struct ipv6hdr *ip6h, - struct tcphdr *th, - const struct iovec *iov, size_t iov_cnt, size_t doffset, - uint32_t seq, bool no_tcp_csum) -{ - const struct 
flowside *tapside = TAPFLOW(conn); - size_t dlen = iov_size(iov, iov_cnt) - doffset; - size_t l4len = dlen + sizeof(*th); + ip6h->payload_len = htons(l4len); + ip6h->saddr = tapside->oaddr.a6; + ip6h->daddr = tapside->eaddr.a6; - ip6h->payload_len = htons(l4len); - ip6h->saddr = tapside->oaddr.a6; - ip6h->daddr = tapside->eaddr.a6; + ip6h->hop_limit = 255; + ip6h->version = 6; + ip6h->nexthdr = IPPROTO_TCP; - ip6h->hop_limit = 255; - ip6h->version = 6; - ip6h->nexthdr = IPPROTO_TCP; + ip6h->flow_lbl[0] = (conn->sock >> 16) & 0xf; + ip6h->flow_lbl[1] = (conn->sock >> 8) & 0xff; + ip6h->flow_lbl[2] = (conn->sock >> 0) & 0xff; - ip6h->flow_lbl[0] = (conn->sock >> 16) & 0xf; - ip6h->flow_lbl[1] = (conn->sock >> 8) & 0xff; - ip6h->flow_lbl[2] = (conn->sock >> 0) & 0xff; + if (!no_tcp_csum) { + psum = proto_ipv6_header_psum(l4len, IPPROTO_TCP, + &ip6h->saddr, + &ip6h->daddr); + } + } tcp_fill_header(th, conn, seq); - if (no_tcp_csum) { + if (no_tcp_csum) th->check = 0; - } else { - uint32_t psum = proto_ipv6_header_psum(l4len, IPPROTO_TCP, - &ip6h->saddr, - &ip6h->daddr); - + else tcp_update_csum(psum, th, iov, iov_cnt, doffset); - } - tap_hdr_update(taph, l4len + sizeof(*ip6h) + sizeof(struct ethhdr)); + tap_hdr_update(taph, l3len + sizeof(struct ethhdr)); } /** diff --git a/tcp_buf.c b/tcp_buf.c index 0e6b67d..61eaf4e 100644 --- a/tcp_buf.c +++ b/tcp_buf.c @@ -198,23 +198,21 @@ static void tcp_l2_buf_fill_headers(const struct tcp_tap_conn *conn, struct iovec *iov, const uint16_t *check, uint32_t seq, bool no_tcp_csum) { + struct tcphdr *th = iov[TCP_IOV_PAYLOAD].iov_base; + struct tap_hdr *taph = iov[TCP_IOV_TAP].iov_base; const struct iovec *tail = &iov[TCP_IOV_PAYLOAD]; const struct flowside *tapside = TAPFLOW(conn); const struct in_addr *a4 = inany_v4(&tapside->oaddr); + struct ipv6hdr *ip6h = NULL; + struct iphdr *ip4h = NULL; - if (a4) { - tcp_fill_headers4(conn, iov[TCP_IOV_TAP].iov_base, - iov[TCP_IOV_IP].iov_base, - iov[TCP_IOV_PAYLOAD].iov_base, - tail, 1, 
sizeof(struct tcphdr), - check, seq, no_tcp_csum); - } else { - tcp_fill_headers6(conn, iov[TCP_IOV_TAP].iov_base, - iov[TCP_IOV_IP].iov_base, - iov[TCP_IOV_PAYLOAD].iov_base, - tail, 1, sizeof(struct tcphdr), - seq, no_tcp_csum); - } + if (a4) + ip4h = iov[TCP_IOV_IP].iov_base; + else + ip6h = iov[TCP_IOV_IP].iov_base; + + tcp_fill_headers(conn, taph, ip4h, ip6h, th, tail, 1, sizeof(*th), + check, seq, no_tcp_csum); } /** diff --git a/tcp_internal.h b/tcp_internal.h index a2de15a..5e5a794 100644 --- a/tcp_internal.h +++ b/tcp_internal.h @@ -180,16 +180,12 @@ struct tcp_info_linux; void tcp_update_csum(uint32_t psum, struct tcphdr *th, const struct iovec *iov, int iov_cnt, size_t doffset); -void tcp_fill_headers4(const struct tcp_tap_conn *conn, - struct tap_hdr *taph, struct iphdr *iph, - struct tcphdr *th, - const struct iovec *iov, size_t iov_cnt, size_t doffset, - const uint16_t *check, uint32_t seq, bool no_tcp_csum); -void tcp_fill_headers6(const struct tcp_tap_conn *conn, - struct tap_hdr *taph, struct ipv6hdr *ip6h, - struct tcphdr *th, - const struct iovec *iov, size_t iov_cnt, size_t doffset, - uint32_t seq, bool no_tcp_csum); +void tcp_fill_headers(const struct tcp_tap_conn *conn, + struct tap_hdr *taph, + struct iphdr *ip4h, struct ipv6hdr *ip6h, + struct tcphdr *th, + const struct iovec *iov, size_t iov_cnt, size_t doffset, + const uint16_t *ip4_check, uint32_t seq, bool no_tcp_csum); int tcp_update_seqack_wnd(const struct ctx *c, struct tcp_tap_conn *conn, bool force_seq, struct tcp_info_linux *tinfo); diff --git a/tcp_vu.c b/tcp_vu.c index 916e35d..5d9f73f 100644 --- a/tcp_vu.c +++ b/tcp_vu.c @@ -104,8 +104,8 @@ int tcp_vu_send_flag(const struct ctx *c, struct tcp_tap_conn *conn, int flags) const struct flowside *tapside = TAPFLOW(conn); size_t l2len, optlen, hdrlen; struct ipv6hdr *ip6h = NULL; + struct iphdr *ip4h = NULL; struct tcp_syn_opts *opts; - struct iphdr *iph = NULL; struct tcphdr *th; struct ethhdr *eh; uint32_t seq; @@ -133,8 +133,8 @@ 
int tcp_vu_send_flag(const struct ctx *c, struct tcp_tap_conn *conn, int flags) if (CONN_V4(conn)) { eh->h_proto = htons(ETH_P_IP); - iph = vu_ip(iov_vu[0].iov_base); - *iph = (struct iphdr)L2_BUF_IP4_INIT(IPPROTO_TCP); + ip4h = vu_ip(iov_vu[0].iov_base); + *ip4h = (struct iphdr)L2_BUF_IP4_INIT(IPPROTO_TCP); th = vu_payloadv4(iov_vu[0].iov_base); } else { @@ -162,15 +162,9 @@ int tcp_vu_send_flag(const struct ctx *c, struct tcp_tap_conn *conn, int flags) elem[0].in_sg[0].iov_len = l2len + sizeof(struct virtio_net_hdr_mrg_rxbuf); - if (CONN_V4(conn)) { - tcp_fill_headers4(conn, NULL, iph, th, iov_vu, 1, - (char *)opts - (char *)iov_vu[0].iov_base, - NULL, seq, true); - } else { - tcp_fill_headers6(conn, NULL, ip6h, th, iov_vu, 1, - (char *)opts - (char *)iov_vu[0].iov_base, - seq, true); - } + tcp_fill_headers(conn, NULL, ip4h, ip6h, th, iov_vu, 1, + (char *)opts - (char *)iov_vu[0].iov_base, + NULL, seq, true); if (*c->pcap) { tcp_vu_update_check(tapside, &elem[0].in_sg[0], 1); @@ -283,7 +277,7 @@ static void tcp_vu_prepare(const struct ctx *c, struct tcp_tap_conn *conn, const struct flowside *toside = TAPFLOW(conn); char *base = iov[0].iov_base; struct ipv6hdr *ip6h = NULL; - struct iphdr *iph = NULL; + struct iphdr *ip4h = NULL; struct tcphdr *th; struct ethhdr *eh; char *data; @@ -306,8 +300,8 @@ static void tcp_vu_prepare(const struct ctx *c, struct tcp_tap_conn *conn, eh->h_proto = htons(ETH_P_IP); - iph = vu_ip(base); - *iph = (struct iphdr)L2_BUF_IP4_INIT(IPPROTO_TCP); + ip4h = vu_ip(base); + *ip4h = (struct iphdr)L2_BUF_IP4_INIT(IPPROTO_TCP); th = vu_payloadv4(base); } else { ASSERT(iov[0].iov_len >= sizeof(struct virtio_net_hdr_mrg_rxbuf) + @@ -327,14 +321,10 @@ static void tcp_vu_prepare(const struct ctx *c, struct tcp_tap_conn *conn, th->ack = 1; data = (char *)(th + 1); - if (inany_v4(&toside->eaddr) && inany_v4(&toside->oaddr)) { - tcp_fill_headers4(conn, NULL, iph, th, iov, iov_cnt, data - base, - *check, conn->seq_to_tap, true); - *check = 
&iph->check; - } else { - tcp_fill_headers6(conn, NULL, ip6h, th, iov, iov_cnt, data - base, - conn->seq_to_tap, true); - } + tcp_fill_headers(conn, NULL, ip4h, ip6h, th, iov, iov_cnt, data - base, + *check, conn->seq_to_tap, true); + if (ip4h) + *check = &ip4h->check; } /** -- 2.47.0