Socket 死连接详解

当使用 Socket 进行通信时,由于各种不同的因素,都有可能导致死连接停留在服务器端,假如服务端需要处理的连接较多,就有可能造成服务器资源严重浪费,对此,本文将阐述其原理以及解决方法。

在写 Socket 进行通讯时,我们必须预料到各种可能发生的情况并对其进行处理,通常情况下,有以下两种情况可能造成死连接:

  • 通讯程序编写不完善
  • 网络/硬件故障

a) 通讯程序编写不完善

这里要指出的一点就是,绝大多数程序都是由于程序编写不完善所造成的死连接,即对 Socket 未能进行完善的管理,导致占用端口导致服务器资源耗尽。当然,很多情况下,程序可能不是我们所写,而由于程序代码的复杂、杂乱等原因所导致难以维护也是我们所需要面对的。

网上有很多文章都提到 Socket 长时间处于 CLOSE_WAIT 状态下的问题,说可以使用 Keepalive 选项设置 TCP 心跳来解决,但是却发现设置选项后未能收到效果 。

因此,这里我分享出自己的解决方案:

Windows 中对于枚举系统网络连接有一些非常方便的 API:

  • GetTcpTable : 获得 TCP 连接表
  • GetExtendedTcpTable : 获得扩展后的 TCP 连接表,相比 GetTcpTable 更为强大,可以获取与连接的进程 ID
  • SetTcpEntry : 设置 TCP 连接状态,但据 MSDN 所述,只能设置状态为 DeleteTcb,即删除连接

相信大多数朋友看到这些 API ,就已经了解到我们下一步要做什么了;枚举所有 TCP 连接,筛选出本进程的连接,最后判断是否 CLOSE_WAIT 状态,如果是,则使用 SetTcpEntry 关闭。

其实 Sysinternal 的 TcpView 工具也是应用上述 API 实现其功能的,此工具为我常用的网络诊断工具,同时也可作为一个简单的手动式网络防火墙。

下面来看 Zealic 封装后的代码:


TcpManager.cs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/**
<code>
  <revsion>$Rev: 0 $</revision>
  <owner name="Zealic" mail="rszealic(at)gmail.com" />
</code>
**/
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;


namespace Zealic.Network
{
    /// <summary>
    /// TCP 管理器
    /// </summary>
    public static class TcpManager
    {
        #region PInvoke define
        private const int TCP_TABLE_OWNER_PID_ALL = 5;

        [DllImport("iphlpapi.dll", SetLastError = true)]
        private static extern uint GetExtendedTcpTable(
            IntPtr pTcpTable, ref int dwOutBufLen, bool sort, int ipVersion, int tblClass, int reserved);

        [DllImport("iphlpapi.dll")]
        private static extern int SetTcpEntry(ref MIB_TCPROW pTcpRow);


        [StructLayout(LayoutKind.Sequential)]
        private struct MIB_TCPROW
        {
            public TcpState dwState;
            public int dwLocalAddr;
            public int dwLocalPort;
            public int dwRemoteAddr;
            public int dwRemotePort;
        }

        [StructLayout(LayoutKind.Sequential)]
        private struct MIB_TCPROW_OWNER_PID
        {
            public TcpState dwState;
            public uint dwLocalAddr;
            public int dwLocalPort;
            public uint dwRemoteAddr;
            public int dwRemotePort;
            public int dwOwningPid;
        }

        [StructLayout(LayoutKind.Sequential)]
        private struct MIB_TCPTABLE_OWNER_PID
        {
            public uint dwNumEntries;
            private MIB_TCPROW_OWNER_PID table;
        }
        #endregion

        private static MIB_TCPROW_OWNER_PID[] GetAllTcpConnections()
        {
            const int NO_ERROR = 0;
            const int IP_v4 = 2;
            MIB_TCPROW_OWNER_PID[] tTable = null;
            int buffSize = 0;
            GetExtendedTcpTable(IntPtr.Zero, ref buffSize, true, IP_v4, TCP_TABLE_OWNER_PID_ALL, 0);
            IntPtr buffTable = Marshal.AllocHGlobal(buffSize);
            try
            {
                if (NO_ERROR != GetExtendedTcpTable(buffTable, ref buffSize, true, IP_v4, TCP_TABLE_OWNER_PID_ALL, 0)) return null;
                MIB_TCPTABLE_OWNER_PID tab =
                    (MIB_TCPTABLE_OWNER_PID)Marshal.PtrToStructure(buffTable, typeof(MIB_TCPTABLE_OWNER_PID));
                IntPtr rowPtr = (IntPtr)((long)buffTable + Marshal.SizeOf(tab.dwNumEntries));
                tTable = new MIB_TCPROW_OWNER_PID[tab.dwNumEntries];

                int rowSize = Marshal.SizeOf(typeof(MIB_TCPROW_OWNER_PID));
                for (int i = 0; i < tab.dwNumEntries; i++)
                {
                    MIB_TCPROW_OWNER_PID tcpRow =
                        (MIB_TCPROW_OWNER_PID)Marshal.PtrToStructure(rowPtr, typeof(MIB_TCPROW_OWNER_PID));
                    tTable[i] = tcpRow;
                    rowPtr = (IntPtr)((int)rowPtr + rowSize);
                }
            }
            finally
            {
                Marshal.FreeHGlobal(buffTable);
            }
            return tTable;
        }

        private static int TranslatePort(int port)
        {
            return ((port & 0xFF) << 8 | (port & 0xFF00) >> 8);
        }

        public static bool Kill(TcpConnectionInfo conn)
        {
            if (conn == null) throw new ArgumentNullException("conn");
            MIB_TCPROW row = new MIB_TCPROW();
            row.dwState = TcpState.DeleteTcb;
#pragma warning disable 612,618
            row.dwLocalAddr = (int)conn.LocalEndPoint.Address.Address;
#pragma warning restore 612,618
            row.dwLocalPort = TranslatePort(conn.LocalEndPoint.Port);
#pragma warning disable 612,618
            row.dwRemoteAddr = (int)conn.RemoteEndPoint.Address.Address;
#pragma warning restore 612,618
            row.dwRemotePort = TranslatePort(conn.RemoteEndPoint.Port);
            return SetTcpEntry(ref row) == 0;
        }

        public static bool Kill(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint)
        {
            if (localEndPoint == null) throw new ArgumentNullException("localEndPoint");
            if (remoteEndPoint == null) throw new ArgumentNullException("remoteEndPoint");
            MIB_TCPROW row = new MIB_TCPROW();
            row.dwState = TcpState.DeleteTcb;
#pragma warning disable 612,618
            row.dwLocalAddr = (int)localEndPoint.Address.Address;
#pragma warning restore 612,618
            row.dwLocalPort = TranslatePort(localEndPoint.Port);
#pragma warning disable 612,618
            row.dwRemoteAddr = (int)remoteEndPoint.Address.Address;
#pragma warning restore 612,618
            row.dwRemotePort = TranslatePort(remoteEndPoint.Port);
            return SetTcpEntry(ref row) == 0;
        }


        public static TcpConnectionInfo[] GetTableByProcess(int pid)
        {
            MIB_TCPROW_OWNER_PID[] tcpRows = GetAllTcpConnections();
            if (tcpRows == null) return null;
            List<TcpConnectionInfo> list = new List<TcpConnectionInfo>();
            foreach (MIB_TCPROW_OWNER_PID row in tcpRows)
            {
                if (row.dwOwningPid == pid)
                {
                    int localPort = TranslatePort(row.dwLocalPort);
                    int remotePort = TranslatePort(row.dwRemotePort);
                    TcpConnectionInfo conn =
                        new TcpConnectionInfo(
                            new IPEndPoint(row.dwLocalAddr, localPort),
                            new IPEndPoint(row.dwRemoteAddr, remotePort),
                            row.dwState);
                    list.Add(conn);
                }
            }
            return list.ToArray();
        }

        public static TcpConnectionInfo[] GetTalbeByCurrentProcess()
        {
            return GetTableByProcess(Process.GetCurrentProcess().Id);
        }

    }
}

TcpConnectionInfo.cs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
<code>
  <revsion>$Rev: 608 $</revision>
  <owner name="Zealic" mail="rszealic(at)gmail.com" />
</code>
**/
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.NetworkInformation;


namespace Zealic.Network
{
    /// <summary>
    /// TCP 连接信息
    /// </summary>
    public sealed class TcpConnectionInfo : IEquatable<TcpConnectionInfo>, IEqualityComparer<TcpConnectionInfo>
    {
        private readonly IPEndPoint _LocalEndPoint;
        private readonly IPEndPoint _RemoteEndPoint;
        private readonly TcpState _State;

        public TcpConnectionInfo(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, TcpState state)
        {
            if (localEndPoint == null) throw new ArgumentNullException("localEndPoint");
            if (remoteEndPoint == null) throw new ArgumentNullException("remoteEndPoint");
            _LocalEndPoint = localEndPoint;
            _RemoteEndPoint = remoteEndPoint;
            _State = state;
        }

        public IPEndPoint LocalEndPoint
        {
            get { return _LocalEndPoint; }
        }

        public IPEndPoint RemoteEndPoint
        {
            get { return _RemoteEndPoint; }
        }

        public TcpState State
        {
            get { return _State; }
        }

        public bool Equals(TcpConnectionInfo x, TcpConnectionInfo y)
        {
            return (x.LocalEndPoint.Equals(y.LocalEndPoint) && x.RemoteEndPoint.Equals(y.RemoteEndPoint));
        }

        public int GetHashCode(TcpConnectionInfo obj)
        {
            return obj.LocalEndPoint.GetHashCode() ^ obj.RemoteEndPoint.GetHashCode();
        }

        public bool Equals(TcpConnectionInfo other)
        {
            return Equals(this, other);
        }

        public override bool Equals(object obj)
        {
            if (obj == null || !(obj is TcpConnectionInfo))
                return false;
            return Equals(this, (TcpConnectionInfo)obj);
        }

    }
}

至此,我们可以通过 TcpManager 类的 GetTableByProcess 方法获取进程中所有的 TCP 连接信息,然后通过 Kill 方法强制关连接以回收系统资源,虽然很C很GX,但是很有效。

通常情况下,我们可以使用 Timer 来定时检测进程中的 TCP 连接状态,确定其是否处于 CLOSE_WAIT 状态,当超过指定的次数/时间时,就把它干掉。

不过,相对这样的解决方法,我还是推荐在设计 Socket 服务端程序的时候,一定要管理所有的连接,而非上述方法。

b) 网络/硬件故障

现在我们再来看第二种情况,当网络/硬件故障时,如何应对;与上面不同,这样的情况 TCP 可能处于 ESTABLISHED、CLOSE_WAIT、FIN_WAIT 等状态中的任何一种,这时才是 Keepalive 该出马的时候。

默认情况下 Keepalive 的时间设置为两小时,如果是请求比较多的服务端程序,两小时未免太过漫长,等到它时间到,估计连黄花菜都凉了,好在我们可以通过 Socket.IOControl 方法手动设置其属性,以达到我们的目的。

关键代码如下:

1
2
3
4
5
6
7
8
// 假设 accepted 到的 Socket 为变量 client
// ...
// 设置 TCP 心跳,空闲 15 秒,每 5 秒检查一次
byte[] inOptionValues = new byte[4 * 3];
BitConverter.GetBytes((uint)1).CopyTo(inOptionValues, 0);
BitConverter.GetBytes((uint)15000).CopyTo(inOptionValues, 4);
BitConverter.GetBytes((uint)5000).CopyTo(inOptionValues, 8);
client.IOControl(IOControlCode.KeepAliveValues, inOptionValues, null);

以上代码的作用就是设置 TCP 心跳为 5 秒,当三次检测到无法与客户端连接后,将会关闭 Socket。

相信上述代码加上说明,对于有一定基础读者理解起来应该不难,今天到此为止。

c) 结束语

其实对于 Socket 程序设计来说,良好的通信协议才是稳定的保证,类似于这样的问题,如果在应用程序通信协议中加入自己的心跳包,不仅可以处理多种棘手的问题,还可以在心跳中加入自己的简单校验功能,防止包数据被 WPE 等软件篡改。但是,很多情况下这些都不是我们所能决定的,因此,才有了本文中提出的方法。

警告 :本文系 Zealic 创作,并基于 CC 3.0 共享创作许可协议 发布,如果您转载此文或使用其中的代码,请务必先阅读协议内容。

Zealic @ 2008-03-15

View Comments |
Categories: tech
Tags:

Related posts