1 module socks.socks5;
2 
3 import std.socket : AddressFamily;
4 
/// SOCKS protocol version byte (VER) used in the method-selection and request messages.
enum int Socks5Version = 5;
6 
/// Authentication methods exchanged during method selection (RFC 1928, chapter 3).
enum AuthMethod : ubyte
{
    NOAUTH       = 0x00, /// no authentication required
    AUTH         = 0x02, /// username/password (RFC 1929)
    NOTAVAILABLE = 0xFF, /// server reply: none of the offered methods is acceptable
}
13 
/// STATUS byte of the username/password sub-negotiation reply (RFC 1929).
/// RFC 1929 treats any non-zero status as failure; `NO` is one such value.
enum AuthStatus : ubyte
{
    YES = 0x00, /// credentials accepted
    NO  = 0x01, /// credentials rejected
}
19 
/// CMD field of a SOCKS5 request (RFC 1928, chapter 4).
enum RequestCmd : ubyte
{
    CONNECT      = 0x01, /// open a TCP connection to the destination
    BIND         = 0x02, /// have the proxy accept an inbound connection
    UDPASSOCIATE = 0x03, /// relay UDP datagrams
}
26 
/// ATYP field: how a DST.ADDR / BND.ADDR is encoded (RFC 1928, chapter 5).
enum AddressType : ubyte
{
    IPV4   = 0x01, /// 4-byte IPv4 address
    DOMAIN = 0x03, /// one length byte followed by the domain name
    IPV6   = 0x04, /// 16-byte IPv6 address
}
33 
/**
 * REP field of a SOCKS5 reply (RFC 1928, chapter 6).
 *
 * `UNKNOWN` is not assigned by RFC 1928 (0x09-0xFF are unassigned). In this
 * module it is the initial reply code of a `Socks5` client.
 */
enum ReplyCode : ubyte
{
    SUCCEEDED           = 0x00, /// request granted
    FAILURE             = 0x01, /// general SOCKS server failure
    NOTALLOWED          = 0x02, /// connection not allowed by ruleset
    NETWORK_UNREACHABLE = 0x03, /// network unreachable
    HOST_UNREACHABLE    = 0x04, /// host unreachable
    CONNECTION_REFUSED  = 0x05, /// connection refused
    TTL_EXPIRED         = 0x06, /// TTL expired
    CMD_NOTSUPPORTED    = 0x07, /// command not supported
    ADDR_NOTSUPPORTED   = 0x08, /// address type not supported

    UNKNOWN             = 0xFF, /// sentinel: no reply known yet
}
48 
/// True only when `T` is exactly `Socks5Options`; qualified variants such as
/// `const(Socks5Options)` do not match.
enum bool isSocksOptions(T) = is(T == Socks5Options);
51 
/// Settings for one connection through a SOCKS5 proxy.
struct Socks5Options
{
    /// Methods offered to the proxy during method selection, sent in this order.
    /// NOTE(review): in D an array-literal field default is allocated once and
    /// shared by every default-initialised instance — assign a new array rather
    /// than mutating the default's elements in place; verify before relying on it.
    AuthMethod[] authMethods = [AuthMethod.NOAUTH];

    /// Proxy endpoint, passed to the `SocksTCPConnector` when one is supplied.
    string host;
    /// ditto
    ushort port;

    /// Credentials intended for username/password (`AuthMethod.AUTH`) authentication.
    /// NOTE(review): confirm the client actually transmits them.
    string username;
    /// ditto
    string password;

    /// When true and a resolver was supplied, the destination host is resolved
    /// locally before the request is sent.
    bool resolveHost = true;
}
64 
/// Opens the transport to the proxy at `host:port`; returning false aborts the connect.
alias SocksTCPConnector = bool delegate(in string host, in ushort port);
/// Receives bytes from the proxy into `buffer`.
/// Presumably expected to fill the buffer completely, since replies are read
/// into fixed-size buffers — confirm for each transport.
alias SocksDataReader = void delegate(ubyte[] buffer);
/// Sends `data` to the proxy.
alias SocksDataWriter = void delegate(in ubyte[] data);
/// Maps a host name to an address string. The result is parsed as an IPv4/IPv6
/// literal; anything else makes the client send the original name instead.
alias SocksHostnameResolver = string function(in string hostname);
69 
/**
 * Client side of the SOCKS5 protocol (RFC 1928) driven through caller-supplied
 * callbacks: outgoing bytes go to `writer`, incoming bytes come from `reader`.
 *
 * NOTE(review): every `reader` call is assumed to fill the whole buffer it is
 * given — confirm for the transports in use.
 */
struct Socks5
{
    protected:
        SocksTCPConnector connector;
        SocksDataReader reader;
        SocksDataWriter writer;
        SocksHostnameResolver resolver;

        // Outcome of the last `connect`; UNKNOWN until a connect attempt has run.
        ReplyCode _replyCode = ReplyCode.UNKNOWN;

    public:
        /**
         * Params:
         *   reader    = receives bytes from the proxy into the given buffer
         *   writer    = sends the given bytes to the proxy
         *   connector = optional; called with the proxy host/port before the handshake
         *   resolver  = optional; resolves the destination host locally when
         *               `Socks5Options.resolveHost` is set
         */
        @nogc
        this(SocksDataReader reader, SocksDataWriter writer, SocksTCPConnector connector = null, SocksHostnameResolver resolver = null)
        {
            this.connector = connector;
            this.reader = reader;
            this.writer = writer;
            this.resolver = resolver;
        }

        /**
         * Negotiates with the proxy and asks it to CONNECT to `host:port`.
         *
         * Returns: `true` if the proxy replied `SUCCEEDED`. Otherwise `replyCode`
         * holds the reason:
         * - the proxy's reply code,
         * - `NETWORK_UNREACHABLE` if the connector failed,
         * - `NOTALLOWED` if method selection or authentication failed,
         * - `FAILURE` for a malformed reply.
         */
        bool connect(in Socks5Options options, string host, ushort port)
        {
            import std.algorithm.searching : canFind;

            if (connector !is null && !connector(options.host, options.port)) {
                _replyCode = ReplyCode.NETWORK_UNREACHABLE;
                return false;
            }

            immutable AuthMethod method = handshake(options);

            // Proceed only with a method we offered and know how to perform. This
            // also rejects NOTAVAILABLE (0xFF) and malformed handshake replies.
            bool negotiated = false;
            if (options.authMethods.canFind(method)) {
                if (method == AuthMethod.NOAUTH)
                    negotiated = true;
                else if (method == AuthMethod.AUTH)
                    negotiated = authenticate(options);
            }

            if (!negotiated) {
                // RFC 1928 defines no reply code for a failed negotiation.
                _replyCode = ReplyCode.NOTALLOWED;
                return false;
            }

            // Assign the member (not a shadowing local) so `replyCode` reports the answer.
            _replyCode = request(host, port, options.resolveHost);

            return _replyCode == ReplyCode.SUCCEEDED;
        }

        /// Reply code of the last `connect` attempt (UNKNOWN before the first one).
        @property
        ReplyCode replyCode() const
        {
            return _replyCode;
        }

    protected:
        /**
         * Sends the method-selection message and returns the method the proxy chose.
         *
         * Returns: the selected method, or `AuthMethod.NOTAVAILABLE` if the reply
         * does not carry SOCKS version 5.
         */
        AuthMethod handshake(in Socks5Options options)
        {
            // NMETHODS is a single byte, so at most 255 methods can be offered.
            immutable size_t count = options.authMethods.length < ubyte.max
                ? options.authMethods.length : ubyte.max;
            const(AuthMethod)[] methods = options.authMethods[0 .. count];

            // RFC 1928 ch. 3: VER | NMETHODS | METHODS
            ubyte[] message = [Socks5Version, cast(ubyte) methods.length];
            message ~= cast(const(ubyte)[]) methods;
            writer(message);

            // Reply: VER | METHOD
            ubyte[2] answer;
            reader(answer[]);

            // Checked at runtime rather than with `assert`: this is network input,
            // and asserts are compiled out of release builds.
            if (answer[0] != Socks5Version)
                return AuthMethod.NOTAVAILABLE;

            return cast(AuthMethod) answer[1];
        }

        /**
         * Performs the RFC 1929 username/password sub-negotiation.
         *
         * Returns: `true` if the proxy accepted the credentials; `false` if it
         * rejected them or if a credential is longer than 255 bytes.
         */
        bool authenticate(in Socks5Options options)
        {
            // ULEN and PLEN are single bytes, so longer credentials cannot be encoded.
            if (options.username.length > ubyte.max || options.password.length > ubyte.max)
                return false;

            // RFC 1929: VER (0x01) | ULEN | UNAME | PLEN | PASSWD
            ubyte[] message = [0x01, cast(ubyte) options.username.length];
            message ~= cast(const(ubyte)[]) options.username;
            message ~= cast(ubyte) options.password.length;
            message ~= cast(const(ubyte)[]) options.password;
            writer(message);

            // Reply: VER | STATUS; STATUS 0x00 (AuthStatus.YES) means success.
            ubyte[2] answer;
            reader(answer[]);

            return answer[1] == 0x00;
        }

        /**
         * Sends a CONNECT request for `host:port` and reads the proxy's reply.
         *
         * `host` is resolved first when `resolveHostname` is set and a resolver
         * was supplied. If the resulting string is not an IPv4/IPv6 literal, the
         * original `host` is sent as a domain name.
         *
         * Returns: the proxy's reply code; `FAILURE` for a malformed reply; or
         * `ADDR_NOTSUPPORTED`, without sending anything, when `host` is too long
         * to encode as a domain name.
         */
        ReplyCode request(in string host, ushort port, bool resolveHostname)
        {
            Socks5RequestPacket packet;

            auto address = (resolveHostname && resolver !is null)
                ? IpAddress(resolver(host))
                : IpAddress(host);

            if (address.isIp4 || address.isIp6) {
                packet.setIp(address, port);
            } else {
                // DOMAINNAME carries a one-byte length prefix.
                if (host.length > ubyte.max)
                    return ReplyCode.ADDR_NOTSUPPORTED;
                packet.setDomain(host, port);
            }

            writer(packet[]);

            // RFC 1928 ch. 6: VER | REP | RSV | ATYP | BND.ADDR | BND.PORT
            ubyte[4] header;
            reader(header[]);

            if (header[0] != Socks5Version)
                return ReplyCode.FAILURE;

            immutable ReplyCode reply = cast(ReplyCode) header[1];

            // BND.ADDR length depends on ATYP. Consume exactly that many bytes so
            // no part of the reply is left in front of the tunnelled data.
            size_t addressLength;
            if (header[3] == AddressType.IPV4) {
                addressLength = 4;
            } else if (header[3] == AddressType.IPV6) {
                addressLength = 16;
            } else if (header[3] == AddressType.DOMAIN) {
                ubyte[1] nameLength;
                reader(nameLength[]);
                addressLength = nameLength[0];
            } else {
                // Unknown ATYP: the reply length is unknown, so "success" cannot be trusted.
                return reply == ReplyCode.SUCCEEDED ? ReplyCode.FAILURE : reply;
            }

            ubyte[ubyte.max + ushort.sizeof] tail;
            reader(tail[0 .. addressLength + ushort.sizeof]);

            return reply;
        }
}
169 
protected:
    /**
     * Wire form of a SOCKS5 CONNECT request (RFC 1928, chapter 4):
     * VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT.
     *
     * The header bytes come from the field defaults. `setIp` or `setDomain`
     * fills ATYP, DST.ADDR and DST.PORT, and `opSlice` exposes the encoded bytes.
     */
    struct Socks5RequestPacket
    {
        private:
        @safe:
            struct PacketData
            {
                ubyte socksVersion = Socks5Version;        // VER
                ubyte requestCommand = RequestCmd.CONNECT; // CMD
                ubyte rsv = 0x00;                          // RSV
                ubyte addressType;                         // ATYP
                // DST.ADDR followed by DST.PORT. Sized for the largest form:
                // a length byte, a 255-byte domain name and a 2-byte port.
                ubyte[1 + ubyte.max + ushort.sizeof] hostData;
            }

            union
            {
                PacketData packetData;
                ubyte[PacketData.sizeof] buffer;
            }

            // Number of bytes of `hostData` in use (address plus port).
            ushort hostDataLength;

        public:
            /**
             * Encodes `domain` as a DOMAINNAME destination (a length byte, then
             * the name without terminator), followed by `port`.
             */
            void setDomain(string domain, ushort port)
            in
            {
                assert(domain.length <= ubyte.max);
            }
            do
            {
                packetData.addressType = AddressType.DOMAIN;

                immutable ubyte nameLength = cast(ubyte) domain.length;
                packetData.hostData[0] = nameLength;
                packetData.hostData[1 .. 1 + nameLength] = cast(const(ubyte)[]) domain;
                hostDataLength = cast(ushort)(1 + nameLength + ushort.sizeof);

                setPort(port);
            }

            /**
             * Encodes an IPv4 or IPv6 destination followed by `port`.
             * Both address forms are written in network (big-endian) byte order.
             */
            void setIp(IpAddress ipAddress, ushort port)
            in
            {
                assert(ipAddress.isIp4 || ipAddress.isIp6);
            }
            do
            {
                import std.bitmanip : nativeToBigEndian;

                if (ipAddress.isIp4) {
                    packetData.addressType = AddressType.IPV4;
                    // `ip4` is a host-order integer, so convert it to big-endian bytes.
                    packetData.hostData[0 .. uint.sizeof] = ipAddress.ip4.nativeToBigEndian;
                    hostDataLength = uint.sizeof + ushort.sizeof;
                } else if (ipAddress.isIp6) {
                    packetData.addressType = AddressType.IPV6;
                    // `ip6` bytes are already most-significant first (network order).
                    packetData.hostData[0 .. 16] = ipAddress.ip6;
                    hostDataLength = 16 + ushort.sizeof;
                }

                setPort(port);
            }

            /// Encoded request: the fixed header plus the used part of `hostData`.
            /// The returned slice refers to this packet's own storage.
            ubyte[] opSlice()
            {
                return buffer[0 .. PacketData.hostData.offsetof + hostDataLength];
            }

        protected:
            /// Writes `port` big-endian into the last two used bytes of `hostData`.
            void setPort(ushort port)
            {
                import std.bitmanip : nativeToBigEndian;

                packetData.hostData[hostDataLength - ushort.sizeof .. hostDataLength] = port.nativeToBigEndian;
            }
    }
244 
245     /**
246      * IP4 or IP6 address representation
247      */
248     struct IpAddress
249     {
250         import std.socket : InternetAddress, Internet6Address, SocketException;
251 
252         @safe:
253 
254         this(string ipString)
255         {
256             _ip4address = InternetAddress.parse(ipString);
257             if (_ip4address != InternetAddress.ADDR_NONE) {
258                 addressFamily = AddressFamily.INET;
259 
260                 return;
261             }
262 
263             try {
264                 ip6 = Internet6Address.parse(ipString);
265 
266                 return;
267             } catch (SocketException se) {
268 
269             }
270 
271             addressFamily = AddressFamily.UNSPEC;
272         }
273 
274         AddressFamily addressFamily;
275         union
276         {
277             uint      _ip4address;
278             ubyte[16] _ip6address;
279         }
280 
281         @property
282         void ip4(uint value)
283         {
284             addressFamily = AddressFamily.INET;
285             _ip4address = value;
286         }
287 
288         @property @safe
289         uint ip4()
290         {
291             return _ip4address;
292         }
293 
294         @property
295         bool isIp4()
296         {
297             return addressFamily == AddressFamily.INET;
298         }
299 
300         @property
301         void ip6(ubyte[16] value)
302         {
303             addressFamily = AddressFamily.INET6;
304             _ip6address = value;
305         }
306 
307         @property @safe
308         ubyte[16] ip6()
309         {
310             return _ip6address;
311         }
312 
313         @property
314         bool isIp6()
315         {
316             return addressFamily == AddressFamily.INET6;
317         }
318     }
319 
320     unittest
321     {
322         auto ip = IpAddress("127.0.0.1");
323         assert(ip.isIp4);
324         assert(ip.ip4 == 2130706433);
325 
326         import std.bigint;
327         ip = IpAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334");
328         assert(ip.isIp6);
329 
330         auto ipInt = BigInt(0);
331         foreach (i; 0..ip.ip6.length) {
332             ipInt += BigInt(ip.ip6[i]) << 8*(15-i);
333         }
334 
335         assert(ipInt == BigInt("42540766452641154071740215577757643572"));
336     }