aboutsummaryrefslogtreecommitdiff
path: root/plugins/ice/src/dtls_srtp.vala
blob: a21c242bcb009664a9877139c686a0988e338e36 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
using GnuTLS;

/**
 * Runs a DTLS handshake over an externally supplied datagram transport and sets up SRTP
 * sessions keyed from that handshake.
 *
 * Incoming datagrams are handed in through {@link process_incoming_data} (or directly through
 * {@link on_data_rec}); datagrams produced by the handshake are emitted via {@link send_data}.
 * Once {@link setup_dtls_connection} has completed successfully, {@link process_incoming_data}
 * and {@link process_outgoing_data} decrypt/encrypt RTP (component 1) and RTCP (component 2).
 */
public class DtlsSrtp {

    /**
     * Emitted when the DTLS layer wants to send a datagram to the peer.
     *
     * Note: this is emitted from the handshake worker thread (see {@link setup_dtls_connection}).
     */
    public signal void send_data(uint8[] data);

    private X509.Certificate[] own_cert;
    private X509.PrivateKey private_key;

    // buffer_queue is filled by on_data_rec() and drained by the GnuTLS pull callbacks, which run
    // on the handshake thread; buffer_mutex guards the queue and buffer_cond signals new data.
    private Cond buffer_cond = new Cond();
    private Mutex buffer_mutex = new Mutex();
    private Gee.LinkedList<Bytes> buffer_queue = new Gee.LinkedList<Bytes>();
    // Not read anywhere in this class; kept for compatibility.
    private uint pull_timeout = uint.MAX;
    private string peer_fingerprint;

    // Both stay null until the DTLS handshake has succeeded.
    private Crypto.Srtp.Session encrypt_session;
    private Crypto.Srtp.Session decrypt_session;

    /**
     * Creates a new instance with a freshly generated key pair and self-signed certificate.
     *
     * @return the ready-to-use instance
     * @throws GLib.Error if generating the credentials fails
     */
    public static DtlsSrtp setup() throws GLib.Error {
        var obj = new DtlsSrtp();
        obj.generate_credentials();
        return obj;
    }

    /**
     * Returns the fingerprint of our own certificate as lowercase, colon-separated hex bytes.
     *
     * @param digest_algo digest used to compute the fingerprint
     */
    internal string get_own_fingerprint(DigestAlgorithm digest_algo) {
        return format_certificate(own_cert[0], digest_algo);
    }

    /**
     * Sets the fingerprint the peer's certificate is compared against during the handshake.
     *
     * @param fingerprint colon-separated hex fingerprint (compared case-insensitively)
     */
    public void set_peer_fingerprint(string fingerprint) {
        this.peer_fingerprint = fingerprint;
    }

    /**
     * Handles a datagram received from the peer.
     *
     * While no decrypt session exists, data on component 1 is queued for the DTLS handshake and
     * null is returned. Once a decrypt session exists, component 1 is decrypted as RTP and
     * component 2 as RTCP.
     *
     * @param component_id 1 for RTP, 2 for RTCP
     * @param data the received datagram
     * @return the decrypted payload, or null if there is nothing to hand on
     */
    public uint8[] process_incoming_data(uint component_id, uint8[] data) {
        if (decrypt_session != null) {
            if (component_id == 1) return decrypt_session.decrypt_rtp(data);
            if (component_id == 2) return decrypt_session.decrypt_rtcp(data);
        } else if (component_id == 1) {
            on_data_rec(data);
        }
        return null;
    }

    /**
     * Encrypts a datagram that is about to be sent to the peer.
     *
     * @param component_id 1 for RTP, 2 for RTCP
     * @param data the plain datagram
     * @return the encrypted datagram, or null if no encrypt session exists yet or the component
     *         is neither 1 nor 2
     */
    public uint8[] process_outgoing_data(uint component_id, uint8[] data) {
        if (encrypt_session != null) {
            if (component_id == 1) return encrypt_session.encrypt_rtp(data);
            if (component_id == 2) return encrypt_session.encrypt_rtcp(data);
        }
        return null;
    }

    /**
     * Queues a received datagram for the DTLS layer and wakes up a waiting pull callback.
     *
     * @param data the datagram; ownership is taken
     */
    public void on_data_rec(owned uint8[] data) {
        buffer_mutex.lock();
        buffer_queue.add(new Bytes.take(data));
        buffer_cond.signal();
        buffer_mutex.unlock();
    }

    /**
     * Generates a 2048 bit RSA key and a self-signed certificate and stores them in this instance.
     *
     * @throws GLib.Error if key generation fails
     */
    private void generate_credentials() throws GLib.Error {
        int err = 0;

        private_key = X509.PrivateKey.create();
        err = private_key.generate(PKAlgorithm.RSA, 2048);
        throw_if_error(err);

        // The certificate must already be valid when it is used, so the activation time is
        // back-dated by a day (this also tolerates clock skew between the peers). The expiration
        // time is three days from now.
        var now = new DateTime.now_utc();
        var start_time = now.add_days(-1);
        var end_time = now.add_days(3);

        X509.Certificate cert = X509.Certificate.create();
        cert.set_key(private_key);
        cert.set_version(1);
        cert.set_activation_time ((time_t) start_time.to_unix ());
        cert.set_expiration_time ((time_t) end_time.to_unix ());

        uint32 serial = 1;
        cert.set_serial(&serial, sizeof(uint32));

        // Self-signed: the certificate is its own issuer.
        cert.sign(cert, private_key);

        own_cert = new X509.Certificate[] { (owned)cert };
    }

    /**
     * Performs the DTLS handshake and, on success, installs the SRTP encrypt/decrypt sessions.
     *
     * The handshake itself runs on a worker thread; this coroutine is resumed through an idle
     * source once the worker is done. On any failure a warning is logged and the SRTP sessions
     * are left unset.
     *
     * @param server true to act as DTLS server, false to act as DTLS client
     */
    public async void setup_dtls_connection(bool server) {
        InitFlags server_or_client = server ? InitFlags.SERVER : InitFlags.CLIENT;
        debug("Setting up DTLS connection. We're %s", server_or_client.to_string());

        CertificateCredentials cert_cred = CertificateCredentials.create();
        int err = cert_cred.set_x509_key(own_cert, private_key);
        try {
            throw_if_error(err);
        } catch (Error e) {
            warning("Setting DTLS credentials failed: %s", e.message);
            return;
        }

        Session? session = Session.create(server_or_client | InitFlags.DATAGRAM);
        session.enable_heartbeat(1);
        session.set_srtp_profile_direct("SRTP_AES128_CM_HMAC_SHA1_80");
        session.set_credentials(GnuTLS.CredentialsType.CERTIFICATE, cert_cred);
        session.server_set_request(CertificateRequest.REQUEST);
        session.set_priority_from_string("NORMAL:!VERS-TLS-ALL:+VERS-DTLS-ALL:+CTYPE-CLI-X509");

        // The static callbacks below recover `this` from the transport pointer.
        session.set_transport_pointer(this);
        session.set_pull_function(pull_function);
        session.set_pull_timeout_function(pull_timeout_function);
        session.set_push_function(push_function);
        session.set_verify_function(verify_function);

        Thread<int> thread = new Thread<int> (null, () => {
            DateTime maximum_time = new DateTime.now_utc().add_seconds(20);
            int result;
            do {
                result = session.handshake();

                DateTime current_time = new DateTime.now_utc();
                if (maximum_time.compare(current_time) < 0) {
                    warning("DTLS handshake timeouted");
                    result = -1;
                    break;
                }
            } while (result < 0 && !((ErrorCode)result).is_fatal());
            // Always resume the coroutine, also on timeout. Otherwise the `yield` below would
            // never return.
            Idle.add(setup_dtls_connection.callback);
            return result;
        });
        yield;
        err = thread.join();

        if (err < 0) {
            warning("DTLS handshake failed (error %d)", err);
            return;
        }

        uint8[] km = new uint8[150];
        Datum? client_key, client_salt, server_key, server_salt;
        session.get_srtp_keys(km, km.length, out client_key, out client_salt, out server_key, out server_salt);
        if (client_key == null || client_salt == null || server_key == null || server_salt == null) {
            warning("SRTP client/server key/salt null");
            return;
        }

        Crypto.Srtp.Session encrypt_session = new Crypto.Srtp.Session(Crypto.Srtp.Encryption.AES_CM, Crypto.Srtp.Authentication.HMAC_SHA1, 10, Crypto.Srtp.Prf.AES_CM, 0);
        Crypto.Srtp.Session decrypt_session = new Crypto.Srtp.Session(Crypto.Srtp.Encryption.AES_CM, Crypto.Srtp.Authentication.HMAC_SHA1, 10, Crypto.Srtp.Prf.AES_CM, 0);

        // Each side encrypts with its own key material and decrypts with the peer's.
        if (server) {
            encrypt_session.setkey(server_key.extract(), server_salt.extract());
            decrypt_session.setkey(client_key.extract(), client_salt.extract());
        } else {
            encrypt_session.setkey(client_key.extract(), client_salt.extract());
            decrypt_session.setkey(server_key.extract(), server_salt.extract());
        }

        this.encrypt_session = (owned)encrypt_session;
        this.decrypt_session = (owned)decrypt_session;
    }

    /**
     * GnuTLS pull callback: blocks until a datagram is queued and copies it into //buffer//.
     *
     * @param transport_ptr the DtlsSrtp instance registered as transport pointer
     * @param buffer destination buffer provided by GnuTLS
     * @return the number of bytes written to //buffer//
     */
    private static ssize_t pull_function(void* transport_ptr, uint8[] buffer) {
        DtlsSrtp self = transport_ptr as DtlsSrtp;

        self.buffer_mutex.lock();
        while (self.buffer_queue.size == 0) {
            self.buffer_cond.wait(self.buffer_mutex);
        }
        Bytes data = self.buffer_queue.remove_at(0);
        self.buffer_mutex.unlock();

        unowned uint8[] payload = data.get_data();
        // Never write past the buffer GnuTLS handed us: the datagram size is peer-controlled.
        int copy_length = int.min(payload.length, buffer.length);
        if (copy_length < payload.length) {
            warning("DTLS datagram of %d bytes truncated to %d bytes", payload.length, copy_length);
        }
        if (copy_length > 0) {
            Memory.copy(buffer, payload, copy_length);
        }

        // The callback should return 0 on connection termination, a positive number indicating the number of bytes received, and -1 on error.
        return (ssize_t)copy_length;
    }

    /**
     * GnuTLS pull-timeout callback: waits up to //ms// milliseconds for a datagram to be queued.
     *
     * @param transport_ptr the DtlsSrtp instance registered as transport pointer
     * @param ms maximum time to wait, in milliseconds
     * @return 1 if a datagram is available, 0 if the timeout elapsed without data
     */
    private static int pull_timeout_function(void* transport_ptr, uint ms) {
        DtlsSrtp self = transport_ptr as DtlsSrtp;

        // Cond.wait_until() takes an absolute monotonic time in microseconds.
        int64 end_time = get_monotonic_time() + ((int64) ms) * 1000;

        self.buffer_mutex.lock();
        while (self.buffer_queue.size == 0) {
            // wait_until() returns false once end_time has passed; loop again on wakeups that
            // left the queue empty.
            if (!self.buffer_cond.wait_until(self.buffer_mutex, end_time)) {
                break;
            }
        }
        bool data_available = self.buffer_queue.size > 0;
        self.buffer_mutex.unlock();

        // The callback should return 0 on timeout, a positive number if data can be received, and -1 on error.
        return data_available ? 1 : 0;
    }

    /**
     * GnuTLS push callback: forwards an outgoing datagram through {@link send_data}.
     *
     * @param transport_ptr the DtlsSrtp instance registered as transport pointer
     * @param buffer the datagram to send
     * @return the length of //buffer//
     */
    private static ssize_t push_function(void* transport_ptr, uint8[] buffer) {
        DtlsSrtp self = transport_ptr as DtlsSrtp;
        self.send_data(buffer);

        // The callback should return a positive number indicating the bytes sent, and -1 on error.
        return (ssize_t)buffer.length;
    }

    /**
     * GnuTLS certificate verification callback.
     *
     * @param session the session whose peer certificate is to be verified
     * @return 0 to continue the handshake, 1 to abort it
     */
    private static int verify_function(Session session) {
        DtlsSrtp self = session.get_transport_pointer() as DtlsSrtp;
        try {
            bool valid = self.verify_peer_cert(session);
            if (!valid) {
                warning("DTLS certificate invalid. Aborting handshake.");
                return 1;
            }
        } catch (Error e) {
            warning("Error during DTLS certificate validation: %s. Aborting handshake.", e.message);
            return 1;
        }

        // The callback function should return 0 for the handshake to continue or non-zero to terminate.
        return 0;
    }

    /**
     * Checks that the SHA-256 fingerprint of the peer's first certificate equals the fingerprint
     * set through {@link set_peer_fingerprint} (case-insensitive comparison).
     *
     * @param session the session to take the peer certificates from
     * @return true if the fingerprints match; false if they differ, the peer sent no certificate
     *         or no peer fingerprint has been set
     */
    private bool verify_peer_cert(Session session) throws GLib.Error {
        // Without an advertised fingerprint there is nothing to verify against.
        if (this.peer_fingerprint == null) {
            warning("No peer fingerprint set");
            return false;
        }

        unowned Datum[] cert_datums = session.get_peer_certificates();
        if (cert_datums.length == 0) {
            warning("No peer certs");
            return false;
        }
        if (cert_datums.length > 1) warning("More than one peer cert");

        X509.Certificate peer_cert = X509.Certificate.create();
        peer_cert.import(ref cert_datums[0], CertificateFormat.DER);

        string peer_fp_str = format_certificate(peer_cert, DigestAlgorithm.SHA256);
        if (peer_fp_str.down() != this.peer_fingerprint.down()) {
            warning("First cert in peer cert list doesn't equal advertised one %s vs %s", peer_fp_str, this.peer_fingerprint);
            return false;
        }

        return true;
    }

    /**
     * Formats a certificate's fingerprint as lowercase hex bytes separated by colons.
     *
     * @param certificate the certificate to fingerprint
     * @param digest_algo digest used to compute the fingerprint
     * @return the formatted fingerprint, e.g. "ab:cd:ef"
     */
    private string format_certificate(X509.Certificate certificate, DigestAlgorithm digest_algo) {
        uint8[] buf = new uint8[512];
        size_t buf_out_size = 512;
        certificate.get_fingerprint(digest_algo, buf, ref buf_out_size);

        var sb = new StringBuilder();
        for (int i = 0; i < buf_out_size; i++) {
            if (i > 0) {
                sb.append(":");
            }
            sb.append_printf("%02x", buf[i]);
        }
        return sb.str;
    }

    /**
     * Copies //size// bytes starting at //data// into a newly allocated array.
     *
     * @param data pointer to the first byte to copy
     * @param size number of bytes to copy
     * @return a new array holding the copied bytes
     */
    private uint8[] uint8_pt_to_a(uint8* data, uint size) {
        uint8[] ret = new uint8[size];
        for (uint i = 0; i < size; i++) {
            ret[i] = data[i];
        }
        return ret;
    }
}