1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
|
/* SPDX-License-Identifier: GPL-2.0-only */
#ifndef __NET_PSP_HELPERS_H
#define __NET_PSP_HELPERS_H
#include <linux/skbuff.h>
#include <linux/rcupdate.h>
#include <linux/udp.h>
#include <net/sock.h>
#include <net/tcp.h>
#include <net/psp/types.h>
/* Only used through pointers below, so a forward declaration suffices. */
struct inet_timewait_sock;

/*
 * Driver-facing API.
 * Implementations are out of line (not part of this header).
 */
struct psp_dev *psp_dev_create(struct net_device *netdev,
			       struct psp_dev_ops *psd_ops,
			       struct psp_dev_caps *psd_caps,
			       void *priv_ptr);
void psp_dev_unregister(struct psp_dev *psd);
bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb,
			 __be32 spi, u8 ver, __be16 sport);
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation,
		bool strip_icv);

/* Kernel-facing API */
void psp_assoc_put(struct psp_assoc *pas);
/*
 * psp_assoc_drv_data() - access the driver-private part of an association
 * @pas: association to query
 *
 * Return: @pas->drv_data, as an untyped pointer.
 */
static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
{
	void *drv_priv = pas->drv_data;

	return drv_priv;
}
#if IS_ENABLED(CONFIG_INET_PSP)
/* Out-of-line helpers, only declared when CONFIG_INET_PSP is enabled. */
unsigned int psp_key_size(u32 version);

/* Socket-side association helpers */
void psp_sk_assoc_free(struct sock *sk);
void
psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb);

/* Timewait-socket association helpers */
void
psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
/*
 * psp_sk_assoc() - read the PSP association pointer of a full socket
 * @sk: socket
 *
 * rcu_dereference_check() accepts either an RCU read-side critical section
 * or the socket lock being held (lockdep_sock_is_held()).
 *
 * Return: @sk->psp_assoc, possibly NULL.
 */
static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
{
	struct psp_assoc *assoc;

	assoc = rcu_dereference_check(sk->psp_assoc,
				      lockdep_sock_is_held(sk));
	return assoc;
}
/*
 * psp_enqueue_set_decrypted() - flag an skb queued on a PSP socket
 * @sk: socket the skb is being queued on
 * @skb: buffer to flag
 *
 * Sets @skb->decrypted only when @sk has an association whose Tx SPI is
 * non-zero; otherwise @skb is left untouched.
 */
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
	struct psp_assoc *assoc = psp_sk_assoc(sk);

	if (!assoc || !assoc->tx.spi)
		return;

	skb->decrypted = 1;
}
/*
 * __psp_skb_coalesce_diff() - fold PSP-extension differences into @diffs
 * @one: first skb
 * @two: second skb
 * @diffs: differences already accumulated by the caller
 *
 * Return: @diffs, plus a non-zero contribution when exactly one skb carries
 * an SKB_EXT_PSP extension, or when both carry one with differing bytes.
 * The byte comparison is skipped once @diffs is already non-zero.
 */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs)
{
	struct psp_skb_ext *ext_one = skb_ext_find(one, SKB_EXT_PSP);
	struct psp_skb_ext *ext_two = skb_ext_find(two, SKB_EXT_PSP);

	/* Presence mismatch: evaluates to 1 iff exactly one lookup succeeded. */
	diffs |= (!ext_one) != (!ext_two);

	/* Already different, or neither skb has the extension. */
	if (diffs || likely(!ext_one))
		return diffs;

	/* Both extensions present (presence matched above): compare bytes. */
	return diffs | memcmp(ext_one, ext_two, sizeof(*ext_one));
}
static inline bool
psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
{
bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
u32 end_seq = TCP_SKB_CB(skb)->end_seq;
u32 seq = TCP_SKB_CB(skb)->seq;
bool pure_fin;
pure_fin = fin && end_seq - seq == 1;
return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
}
static inline bool
psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
{
return pse && pas->rx.spi == pse->spi &&
pas->generation == pse->generation &&
pas->version == pse->version &&
pas->dev_id == pse->dev_id;
}
/*
 * __psp_sk_rx_policy_check() - apply PSP Rx policy to an incoming skb
 * @skb: received buffer
 * @pas: association of the receiving socket, may be NULL
 *
 * Return: 0 to accept, SKB_DROP_REASON_PSP_INPUT to drop.
 */
static inline enum skb_drop_reason
__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
{
	struct psp_skb_ext *ext = skb_ext_find(skb, SKB_EXT_PSP);

	/* No association: only non-PSP packets are acceptable. */
	if (!pas) {
		if (ext)
			return SKB_DROP_REASON_PSP_INPUT;
		return 0;
	}

	/* PSP packet matching the association's Rx state. */
	if (likely(psp_pse_matches_pas(ext, pas))) {
		/* Record (once) that a matching PSP packet has arrived. */
		if (unlikely(!pas->peer_tx))
			pas->peer_tx = 1;
		return 0;
	}

	/* PSP packet, but not for this association's Rx state. */
	if (ext)
		return SKB_DROP_REASON_PSP_INPUT;

	/* Cleartext: fine while no Tx SPI is installed ... */
	if (!pas->tx.spi)
		return 0;

	/* ... or, before any matching PSP packet, for permitted non-data. */
	if (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))
		return 0;

	return SKB_DROP_REASON_PSP_INPUT;
}
/*
 * psp_sk_rx_policy_check() - PSP Rx policy for a full socket
 * @sk: receiving socket
 * @skb: received buffer
 *
 * Return: 0 to accept, SKB_DROP_REASON_PSP_INPUT to drop.
 */
static inline enum skb_drop_reason
psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
{
	struct psp_assoc *assoc = psp_sk_assoc(sk);

	return __psp_sk_rx_policy_check(skb, assoc);
}
/*
 * psp_twsk_rx_policy_check() - PSP Rx policy for a timewait socket
 * @tw: receiving timewait socket
 * @skb: received buffer
 *
 * Reads @tw->psp_assoc with rcu_dereference(), so the caller must be in an
 * RCU read-side critical section.
 *
 * Return: 0 to accept, SKB_DROP_REASON_PSP_INPUT to drop.
 */
static inline enum skb_drop_reason
psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
{
	struct psp_assoc *assoc = rcu_dereference(tw->psp_assoc);

	return __psp_sk_rx_policy_check(skb, assoc);
}
/*
 * psp_sk_get_assoc_rcu() - PSP association of any kind of inet socket
 * @sk: full, timewait (TCP_TIME_WAIT) or request (TCP_NEW_SYN_RECV) socket
 *
 * Uses rcu_dereference(); the caller must be in an RCU read-side
 * critical section.
 *
 * Return: NULL for non-inet sockets and TCP_NEW_SYN_RECV sockets; the
 * timewait socket's association for TCP_TIME_WAIT; otherwise
 * @sk->psp_assoc.
 */
static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
{
	int sk_state_snap = READ_ONCE(sk->sk_state);

	if (!sk_is_inet(sk) || sk_state_snap == TCP_NEW_SYN_RECV)
		return NULL;

	/* Timewait sockets keep the association in inet_timewait_sock. */
	if (sk_state_snap == TCP_TIME_WAIT)
		return rcu_dereference(inet_twsk(sk)->psp_assoc);

	return rcu_dereference(sk->psp_assoc);
}
/*
 * psp_skb_get_assoc_rcu() - PSP association of the socket owning an skb
 * @skb: buffer
 *
 * Must be called under RCU (see psp_sk_get_assoc_rcu()).
 *
 * Return: NULL unless @skb->decrypted is set and @skb->sk is non-NULL;
 * otherwise the owning socket's association (which may itself be NULL).
 */
static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
	if (skb->decrypted && skb->sk)
		return psp_sk_get_assoc_rcu(skb->sk);

	return NULL;
}
/*
 * psp_sk_overhead() - per-packet PSP encapsulation overhead of a socket
 * @sk: socket
 *
 * Only checks whether @sk->psp_assoc is set (rcu_access_pointer(), no
 * dereference).
 *
 * Return: 0 without an association, otherwise the sum of the UDP header,
 * PSP header and PSP trailer sizes.
 */
static inline unsigned int psp_sk_overhead(const struct sock *sk)
{
	if (!rcu_access_pointer(sk->psp_assoc))
		return 0;

	return sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
}
#else
/* CONFIG_INET_PSP disabled: association lifecycle hooks do nothing. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}

static inline void
psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
{
}

static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}

static inline void
psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb)
{
}
/* CONFIG_INET_PSP disabled: sockets never report an association. */
static inline struct psp_assoc *
psp_sk_assoc(const struct sock *sk)
{
	return NULL;
}

/* CONFIG_INET_PSP disabled: never flags the skb. */
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
}
/*
 * CONFIG_INET_PSP disabled: nothing PSP-specific to compare, so the
 * caller's accumulated @diffs is returned unchanged.
 */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs)
{
	unsigned long unchanged = diffs;

	return unchanged;
}
/* CONFIG_INET_PSP disabled: no PSP Rx policy, every skb is accepted (0). */
static inline enum skb_drop_reason
psp_sk_rx_policy_check(struct sock *sk,
		       struct sk_buff *skb)
{
	return 0;
}

static inline enum skb_drop_reason
psp_twsk_rx_policy_check(struct inet_timewait_sock *tw,
			 struct sk_buff *skb)
{
	return 0;
}
/* CONFIG_INET_PSP disabled: no association is ever found for an skb. */
static inline struct psp_assoc *
psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
	return NULL;
}

/* CONFIG_INET_PSP disabled: no encapsulation, hence no overhead. */
static inline unsigned int
psp_sk_overhead(const struct sock *sk)
{
	return 0;
}
#endif
/*
 * psp_skb_coalesce_diff() - compare the PSP extensions of two skbs
 * @one: first skb
 * @two: second skb
 *
 * Starts from "no prior differences" and delegates to
 * __psp_skb_coalesce_diff().
 *
 * Return: non-zero if the skbs' PSP state differs (always 0 when
 * CONFIG_INET_PSP is disabled).
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	const unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
#endif /* __NET_PSP_HELPERS_H */
|