From 7e388fb2bc784568734592dcb2e863dfa061bed4 Mon Sep 17 00:00:00 2001 From: Marvin W Date: Tue, 18 Apr 2017 17:55:20 +0200 Subject: signal-protocol/omemo: fix null-pointer issues Fixes #44 and #58 --- plugins/omemo/src/account_settings_widget.vala | 2 +- plugins/omemo/src/bundle.vala | 25 ++-- plugins/omemo/src/database.vala | 12 +- plugins/omemo/src/manager.vala | 61 ++++++---- plugins/omemo/src/plugin.vala | 28 ++++- plugins/omemo/src/session_store.vala | 5 +- plugins/omemo/src/stream_module.vala | 127 +++++++++++---------- plugins/signal-protocol/src/signal_helper.c | 38 +++++- plugins/signal-protocol/src/signal_helper.h | 6 +- plugins/signal-protocol/src/simple_ss.vala | 6 +- plugins/signal-protocol/src/store.vala | 20 ++-- plugins/signal-protocol/tests/session_builder.vala | 8 +- .../vapi/signal-protocol-public.vapi | 11 +- 13 files changed, 205 insertions(+), 144 deletions(-) diff --git a/plugins/omemo/src/account_settings_widget.vala b/plugins/omemo/src/account_settings_widget.vala index bc0be3a8..2842c698 100644 --- a/plugins/omemo/src/account_settings_widget.vala +++ b/plugins/omemo/src/account_settings_widget.vala @@ -34,7 +34,7 @@ public class AccountSettingWidget : Plugins.AccountSettingsWidget, Box { if (row == null) { fingerprint.set_markup("%s\n%s".printf(_("Own fingerprint"), _("Will be generated on first connect"))); } else { - uint8[] arr = Base64.decode(row[plugin.db.identity.identity_key_public_base64]); + uint8[] arr = Base64.decode(((!)row)[plugin.db.identity.identity_key_public_base64]); arr = arr[1:arr.length]; string res = ""; foreach (uint8 i in arr) { diff --git a/plugins/omemo/src/bundle.vala b/plugins/omemo/src/bundle.vala index 211dc29b..688f6192 100644 --- a/plugins/omemo/src/bundle.vala +++ b/plugins/omemo/src/bundle.vala @@ -9,21 +9,22 @@ public class Bundle { public Bundle(StanzaNode? node) { this.node = node; + assert(Plugin.ensure_context()); } public int32 signed_pre_key_id { owned get { if (node == null) return -1; - string id = node.get_deep_attribute("signedPreKeyPublic", "signedPreKeyId"); + string? id = ((!)node).get_deep_attribute("signedPreKeyPublic", "signedPreKeyId"); if (id == null) return -1; - return int.parse(id); + return int.parse((!)id); }} public ECPublicKey? signed_pre_key { owned get { if (node == null) return null; - string? key = node.get_deep_string_content("signedPreKeyPublic"); + string? key = ((!)node).get_deep_string_content("signedPreKeyPublic"); if (key == null) return null; try { - return Plugin.context.decode_public_key(Base64.decode(key)); + return Plugin.get_context().decode_public_key(Base64.decode((!)key)); } catch (Error e) { return null; } @@ -31,17 +32,17 @@ public class Bundle { public uint8[]? signed_pre_key_signature { owned get { if (node == null) return null; - string? sig = node.get_deep_string_content("signedPreKeySignature"); + string? sig = ((!)node).get_deep_string_content("signedPreKeySignature"); if (sig == null) return null; - return Base64.decode(sig); + return Base64.decode((!)sig); }} public ECPublicKey? identity_key { owned get { if (node == null) return null; - string? key = node.get_deep_string_content("identityKey"); + string? key = ((!)node).get_deep_string_content("identityKey"); if (key == null) return null; try { - return Plugin.context.decode_public_key(Base64.decode(key)); + return Plugin.get_context().decode_public_key(Base64.decode((!)key)); } catch (Error e) { return null; } @@ -49,9 +50,9 @@ public class Bundle { public ArrayList pre_keys { owned get { ArrayList list = new ArrayList(); - if (node == null || node.get_subnode("prekeys") == null) return list; - node.get_deep_subnodes("prekeys", "preKeyPublic") - .filter((node) => node.get_attribute("preKeyId") != null) + if (node == null || ((!)node).get_subnode("prekeys") == null) return list; + ((!)node).get_deep_subnodes("prekeys", "preKeyPublic") + .filter((node) => ((!)node).get_attribute("preKeyId") != null) .map(PreKey.create) .foreach((key) => list.add(key)); return list; @@ -76,7 +77,7 @@ public class Bundle { string? key = node.get_string_content(); if (key == null) return null; try { - return Plugin.context.decode_public_key(Base64.decode(key)); + return Plugin.get_context().decode_public_key(Base64.decode((!)key)); } catch (Error e) { return null; } diff --git a/plugins/omemo/src/database.vala b/plugins/omemo/src/database.vala index 8d69ca15..a4a4842b 100644 --- a/plugins/omemo/src/database.vala +++ b/plugins/omemo/src/database.vala @@ -12,8 +12,8 @@ public class Database : Qlite.Database { public Column id = new Column.Integer("id") { primary_key = true, auto_increment = true }; public Column account_id = new Column.Integer("account_id") { unique = true, not_null = true }; public Column device_id = new Column.Integer("device_id") { not_null = true }; - public Column identity_key_private_base64 = new Column.Text("identity_key_private_base64") { not_null = true }; - public Column identity_key_public_base64 = new Column.Text("identity_key_public_base64") { not_null = true }; + public Column identity_key_private_base64 = new Column.NonNullText("identity_key_private_base64"); + public Column identity_key_public_base64 = new Column.NonNullText("identity_key_public_base64"); internal IdentityTable(Database db) { base(db, "identity"); @@ -24,7 +24,7 @@ public class Database : Qlite.Database { public class SignedPreKeyTable : Table { public Column identity_id = new Column.Integer("identity_id") { not_null = true }; public Column signed_pre_key_id = new Column.Integer("signed_pre_key_id") { not_null = true }; - public Column record_base64 = new Column.Text("record_base64") { not_null = true }; + public Column record_base64 = new Column.NonNullText("record_base64"); internal SignedPreKeyTable(Database db) { base(db, "signed_pre_key"); @@ -36,7 +36,7 @@ public class Database : Qlite.Database { public class PreKeyTable : Table { public Column identity_id = new Column.Integer("identity_id") { not_null = true }; public Column pre_key_id = new Column.Integer("pre_key_id") { not_null = true }; - public Column record_base64 = new Column.Text("record_base64") { not_null = true }; + public Column record_base64 = new Column.NonNullText("record_base64"); internal PreKeyTable(Database db) { base(db, "pre_key"); @@ -47,9 +47,9 @@ public class Database : Qlite.Database { public class SessionTable : Table { public Column identity_id = new Column.Integer("identity_id") { not_null = true }; - public Column address_name = new Column.Text("name") { not_null = true }; + public Column address_name = new Column.NonNullText("name"); public Column device_id = new Column.Integer("device_id") { not_null = true }; - public Column record_base64 = new Column.Text("record_base64") { not_null = true }; + public Column record_base64 = new Column.NonNullText("record_base64"); internal SessionTable(Database db) { base(db, "session"); diff --git a/plugins/omemo/src/manager.vala b/plugins/omemo/src/manager.vala index 8f50dff9..e4f0ddf2 100644 --- a/plugins/omemo/src/manager.vala +++ b/plugins/omemo/src/manager.vala @@ -70,16 +70,22 @@ public class Manager : StreamInteractionModule, Object { } private void on_pre_message_received(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) { - if (MessageFlag.get_flag(message_stanza) != null && MessageFlag.get_flag(message_stanza).decrypted) { + MessageFlag? flag = MessageFlag.get_flag(message_stanza); + if (flag != null && ((!)flag).decrypted) { message.encryption = Encryption.OMEMO; } } private void on_pre_message_send(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) { if (message.encryption == Encryption.OMEMO) { - StreamModule module = stream_interactor.get_stream(conversation.account).get_module(StreamModule.IDENTITY); + Core.XmppStream? stream = stream_interactor.get_stream(conversation.account); + if (stream == null) { + message.marked = Entities.Message.Marked.UNSENT; + return; + } + StreamModule module = ((!)stream).get_module(StreamModule.IDENTITY); EncryptState enc_state = module.encrypt(message_stanza, conversation.account.bare_jid.to_string()); - MessageState state = null; + MessageState state; lock (message_states) { if (message_states.has_key(message)) { state = message_states.get(message); @@ -95,18 +101,18 @@ public class Manager : StreamInteractionModule, Object { if (!state.will_send_now) { if (message.marked == Entities.Message.Marked.WONTSEND) { - if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) was not sent: $state\n"); + if (Plugin.DEBUG) print(@"OMEMO: message was not sent: $state\n"); } else { - if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) will be delayed: $state\n"); + if (Plugin.DEBUG) print(@"OMEMO: message will be delayed: $state\n"); if (state.waiting_own_sessions > 0) { - module.start_sessions_with(stream_interactor.get_stream(conversation.account), conversation.account.bare_jid.to_string()); + module.start_sessions_with((!)stream, conversation.account.bare_jid.to_string()); } - if (state.waiting_other_sessions > 0) { - module.start_sessions_with(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string()); + if (state.waiting_other_sessions > 0 && message.counterpart != null) { + module.start_sessions_with((!)stream, ((!)message.counterpart).bare_jid.to_string()); } - if (state.waiting_other_devicelist) { - module.request_user_devicelist(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string()); + if (state.waiting_other_devicelist && message.counterpart != null) { + module.request_user_devicelist((!)stream, ((!)message.counterpart).bare_jid.to_string()); } } } @@ -120,8 +126,7 @@ public class Manager : StreamInteractionModule, Object { stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).session_start_failed.connect((jid, device_id) => on_session_started(account, jid, true)); } - private void on_stream_negotiated(Account account) { - Core.XmppStream stream = stream_interactor.get_stream(account); + private void on_stream_negotiated(Account account, Core.XmppStream stream) { stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).request_user_devicelist(stream, account.bare_jid.to_string()); } @@ -134,7 +139,7 @@ public class Manager : StreamInteractionModule, Object { MessageState state = message_states[msg]; if (account.bare_jid.to_string() == jid) { state.waiting_own_sessions--; - } else if (msg.counterpart.bare_jid.to_string() == jid) { + } else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) { state.waiting_other_sessions--; } if (state.should_retry_now()) { @@ -144,8 +149,10 @@ public class Manager : StreamInteractionModule, Object { } } foreach (Entities.Message msg in send_now) { - Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account); - stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true); + if (msg.counterpart == null) continue; + Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation((!)msg.counterpart, account); + if (conv == null) continue; + stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true); } } @@ -158,7 +165,7 @@ public class Manager : StreamInteractionModule, Object { MessageState state = message_states[msg]; if (account.bare_jid.to_string() == jid) { state.waiting_own_devicelist = false; - } else if (msg.counterpart.bare_jid.to_string() == jid) { + } else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) { state.waiting_other_devicelist = false; } if (state.should_retry_now()) { @@ -168,8 +175,10 @@ public class Manager : StreamInteractionModule, Object { } } foreach (Entities.Message msg in send_now) { - Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account); - stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true); + if (msg.counterpart == null) continue; + Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(((!)msg.counterpart), account); + if (conv == null) continue; + stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true); } } @@ -187,7 +196,7 @@ public class Manager : StreamInteractionModule, Object { try { store.identity_key_store.local_registration_id = Random.int_range(1, int32.MAX); - Signal.ECKeyPair key_pair = Plugin.context.generate_key_pair(); + Signal.ECKeyPair key_pair = Plugin.get_context().generate_key_pair(); store.identity_key_store.identity_key_private = key_pair.private.serialize(); store.identity_key_store.identity_key_public = key_pair.public.serialize(); @@ -201,10 +210,10 @@ public class Manager : StreamInteractionModule, Object { // Ignore error } } else { - store.identity_key_store.local_registration_id = row[db.identity.device_id]; - store.identity_key_store.identity_key_private = Base64.decode(row[db.identity.identity_key_private_base64]); - store.identity_key_store.identity_key_public = Base64.decode(row[db.identity.identity_key_public_base64]); - identity_id = row[db.identity.id]; + store.identity_key_store.local_registration_id = ((!)row)[db.identity.device_id]; + store.identity_key_store.identity_key_private = Base64.decode(((!)row)[db.identity.identity_key_private_base64]); + store.identity_key_store.identity_key_public = Base64.decode(((!)row)[db.identity.identity_key_public_base64]); + identity_id = ((!)row)[db.identity.id]; } if (identity_id >= 0) { @@ -218,9 +227,11 @@ public class Manager : StreamInteractionModule, Object { public bool can_encrypt(Entities.Conversation conversation) { - Core.XmppStream stream = stream_interactor.get_stream(conversation.account); + Core.XmppStream? stream = stream_interactor.get_stream(conversation.account); if (stream == null) return false; - return stream.get_module(StreamModule.IDENTITY).is_known_address(conversation.counterpart.bare_jid.to_string()); + StreamModule? module = ((!)stream).get_module(StreamModule.IDENTITY); + if (module == null) return false; + return ((!)module).is_known_address(conversation.counterpart.bare_jid.to_string()); } public static void start(StreamInteractor stream_interactor, Database db) { diff --git a/plugins/omemo/src/plugin.vala b/plugins/omemo/src/plugin.vala index 6a7fc3de..6851aa5e 100644 --- a/plugins/omemo/src/plugin.vala +++ b/plugins/omemo/src/plugin.vala @@ -5,7 +5,23 @@ namespace Dino.Plugins.Omemo { public class Plugin : RootInterface, Object { public const bool DEBUG = false; - public static Signal.Context context; + private static Signal.Context? _context; + public static Signal.Context get_context() { + assert(_context != null); + return (!)_context; + } + public static bool ensure_context() { + lock(_context) { + try { + if (_context == null) { + _context = new Signal.Context(DEBUG); + } + return true; + } catch (Error e) { + return false; + } + } + } public Dino.Application app; public Database db; @@ -14,7 +30,7 @@ public class Plugin : RootInterface, Object { public void registered(Dino.Application app) { try { - context = new Signal.Context(DEBUG); + ensure_context(); this.app = app; this.db = new Database(Path.build_filename(Application.get_storage_dir(), "omemo.db")); this.list_entry = new EncryptionListEntry(this); @@ -26,7 +42,13 @@ public class Plugin : RootInterface, Object { }); Manager.start(this.app.stream_interaction, db); - internationalize(GETTEXT_PACKAGE, app.search_path_generator.get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR)); + string locales_dir; + if (app.search_path_generator != null) { + locales_dir = ((!)app.search_path_generator).get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR); + } else { + locales_dir = LOCALE_INSTALL_DIR; + } + internationalize(GETTEXT_PACKAGE, locales_dir); } catch (Error e) { print(@"Error initializing OMEMO: $(e.message)\n"); } diff --git a/plugins/omemo/src/session_store.vala b/plugins/omemo/src/session_store.vala index f70e16ea..333fdc08 100644 --- a/plugins/omemo/src/session_store.vala +++ b/plugins/omemo/src/session_store.vala @@ -15,11 +15,10 @@ private class BackedSessionStore : SimpleSessionStore { private void init() { try { - Address addr = new Address(); foreach (Row row in db.session.select().with(db.session.identity_id, "=", identity_id)) { - addr.name = row[db.session.address_name]; - addr.device_id = row[db.session.device_id]; + Address addr = new Address(row[db.session.address_name], row[db.session.device_id]); store_session(addr, Base64.decode(row[db.session.record_base64])); + addr.device_id = 0; } } catch (Error e) { print(@"OMEMO: Error while initializing session store: $(e.message)\n"); diff --git a/plugins/omemo/src/stream_module.vala b/plugins/omemo/src/stream_module.vala index 480d8705..8dc0dfc5 100644 --- a/plugins/omemo/src/stream_module.vala +++ b/plugins/omemo/src/stream_module.vala @@ -29,25 +29,26 @@ public class StreamModule : XmppStreamModule { public EncryptState encrypt(Message.Stanza message, string self_bare_jid) { EncryptState status = new EncryptState(); - if (Plugin.context == null) return status; + if (!Plugin.ensure_context()) return status; + if (message.to == null) return status; try { - string name = get_bare_jid(message.to); - if (device_lists.get(self_bare_jid) == null) return status; + string name = get_bare_jid((!)message.to); + if (!device_lists.has_key(self_bare_jid)) return status; status.own_list = true; status.own_devices = device_lists.get(self_bare_jid).size; - if (device_lists.get(name) == null) return status; + if (!device_lists.has_key(name)) return status; status.other_list = true; status.other_devices = device_lists.get(name).size; if (status.own_devices == 0 || status.other_devices == 0) return status; uint8[] key = new uint8[16]; - Plugin.context.randomize(key); + Plugin.get_context().randomize(key); uint8[] iv = new uint8[16]; - Plugin.context.randomize(iv); + Plugin.get_context().randomize(iv); uint8[] ciphertext = aes_encrypt(Cipher.AES_GCM_NOPADDING, key, iv, message.body.data); - StanzaNode header = null; + StanzaNode header; StanzaNode encrypted = new StanzaNode.build("encrypted", NS_URI).add_self_xmlns() .put_node(header = new StanzaNode.build("header", NS_URI) .put_attribute("sid", store.local_registration_id.to_string()) @@ -56,8 +57,7 @@ public class StreamModule : XmppStreamModule { .put_node(new StanzaNode.build("payload", NS_URI) .put_node(new StanzaNode.text(Base64.encode(ciphertext)))); - Address address = new Address(); - address.name = name; + Address address = new Address(name, 0); foreach(int32 device_id in device_lists[name]) { if (is_ignored_device(name, device_id)) { status.other_lost++; @@ -114,57 +114,60 @@ public class StreamModule : XmppStreamModule { public override void attach(XmppStream stream) { Message.Module.require(stream); Pubsub.Module.require(stream); - if (Plugin.context == null) return; + if (!Plugin.ensure_context()) return; - this.store = Plugin.context.create_store(); + this.store = Plugin.get_context().create_store(); store_created(store); stream.get_module(Message.Module.IDENTITY).pre_received_message.connect(on_pre_received_message); - stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id, node), this); + stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id, node), this); } private void on_pre_received_message(XmppStream stream, Message.Stanza message) { - StanzaNode? encrypted = message.stanza.get_subnode("encrypted", NS_URI); - if (encrypted == null || MessageFlag.get_flag(message) != null) return; + StanzaNode? _encrypted = message.stanza.get_subnode("encrypted", NS_URI); + if (_encrypted == null || MessageFlag.get_flag(message) != null || message.from == null) return; + StanzaNode encrypted = (!)_encrypted; + if (!Plugin.ensure_context()) return; MessageFlag flag = new MessageFlag(); message.add_flag(flag); - StanzaNode? header = encrypted.get_subnode("header"); - if (header == null || header.get_attribute_int("sid") <= 0) return; + StanzaNode? _header = encrypted.get_subnode("header"); + if (_header == null) return; + StanzaNode header = (!)_header; + if (header.get_attribute_int("sid") <= 0) return; foreach (StanzaNode key_node in header.get_subnodes("key")) { if (key_node.get_attribute_int("rid") == store.local_registration_id) { try { - uint8[] key = null; - uint8[] ciphertext = Base64.decode(encrypted.get_subnode("payload").get_string_content()); - uint8[] iv = Base64.decode(header.get_subnode("iv").get_string_content()); - Address address = new Address(); - address.name = get_bare_jid(message.from); - address.device_id = header.get_attribute_int("sid"); + string? payload = encrypted.get_deep_string_content("payload"); + string? iv_node = header.get_deep_string_content("iv"); + string? key_node_content = key_node.get_string_content(); + if (payload == null || iv_node == null || key_node_content == null) continue; + uint8[] key; + uint8[] ciphertext = Base64.decode((!)payload); + uint8[] iv = Base64.decode((!)iv_node); + Address address = new Address(get_bare_jid((!)message.from), header.get_attribute_int("sid")); if (key_node.get_attribute_bool("prekey")) { - PreKeySignalMessage msg = Plugin.context.deserialize_pre_key_signal_message(Base64.decode(key_node.get_string_content())); + PreKeySignalMessage msg = Plugin.get_context().deserialize_pre_key_signal_message(Base64.decode((!)key_node_content)); SessionCipher cipher = store.create_session_cipher(address); key = cipher.decrypt_pre_key_signal_message(msg); } else { - SignalMessage msg = Plugin.context.deserialize_signal_message(Base64.decode(key_node.get_string_content())); + SignalMessage msg = Plugin.get_context().deserialize_signal_message(Base64.decode((!)key_node_content)); SessionCipher cipher = store.create_session_cipher(address); key = cipher.decrypt_signal_message(msg); } address.device_id = 0; // TODO: Hack to have address obj live longer - - if (key != null && ciphertext != null && iv != null) { - if (key.length >= 32) { - int authtaglength = key.length - 16; - uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength]; - uint8[] new_key = new uint8[16]; - Memory.copy(new_ciphertext, ciphertext, ciphertext.length); - Memory.copy((uint8*)new_ciphertext + ciphertext.length, (uint8*)key + 16, authtaglength); - Memory.copy(new_key, key, 16); - ciphertext = new_ciphertext; - key = new_key; - } - - message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext)); - flag.decrypted = true; + if (key.length >= 32) { + int authtaglength = key.length - 16; + uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength]; + uint8[] new_key = new uint8[16]; + Memory.copy(new_ciphertext, ciphertext, ciphertext.length); + Memory.copy((uint8*)new_ciphertext + ciphertext.length, (uint8*)key + 16, authtaglength); + Memory.copy(new_key, key, 16); + ciphertext = new_ciphertext; + key = new_key; } + + message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext)); + flag.decrypted = true; } catch (Error e) { if (Plugin.DEBUG) print(@"OMEMO: Signal error while decrypting message: $(e.message)\n"); } @@ -182,17 +185,15 @@ public class StreamModule : XmppStreamModule { public void request_user_devicelist(XmppStream stream, string jid) { if (active_devicelist_requests.add(jid)) { if (Plugin.DEBUG) print(@"OMEMO: requesting device list for $jid\n"); - stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id ?? "", node), this); + stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id ?? "", node), this); } } public void on_devicelist(XmppStream stream, string jid, string id, StanzaNode? node_) { - StanzaNode? node = node_; - if (jid == get_bare_jid(stream.get_flag(Bind.Flag.IDENTITY).my_jid) && store.local_registration_id != 0) { - if (node == null) { - node = new StanzaNode.build("list", NS_URI).add_self_xmlns().put_node(new StanzaNode.build("device", NS_URI)); - } - + StanzaNode node = node_ ?? new StanzaNode.build("list", NS_URI).add_self_xmlns(); + string? my_jid = stream.get_flag(Bind.Flag.IDENTITY).my_jid; + if (my_jid == null) return; + if (jid == get_bare_jid((!)my_jid) && store.local_registration_id != 0) { bool am_on_devicelist = false; foreach (StanzaNode device_node in node.get_subnodes("device")) { int device_id = device_node.get_attribute_int("id"); @@ -223,8 +224,7 @@ public class StreamModule : XmppStreamModule { // TODO: manually request a device list return; } - Address address = new Address(); - address.name = bare_jid; + Address address = new Address(bare_jid, 0); foreach(int32 device_id in device_lists[bare_jid]) { if (!is_ignored_device(bare_jid, device_id)) { address.device_id = device_id; @@ -293,9 +293,7 @@ public class StreamModule : XmppStreamModule { if (signed_pre_key_id < 0 || signed_pre_key == null || identity_key == null || pre_key_id < 0 || pre_key == null) { fail = true; } else { - Address address = new Address(); - address.name = jid; - address.device_id = device_id; + Address address = new Address(jid, device_id); try { if (store.contains_session(address)) { return; @@ -322,13 +320,13 @@ public class StreamModule : XmppStreamModule { } private static void on_self_bundle_result(XmppStream stream, string jid, string? id, StanzaNode? node, Object? storage) { + if (!Plugin.ensure_context()) return; Store store = (Store)storage; Map keys = new HashMap(); - ECPublicKey identity_key = null; - IdentityKeyPair identity_key_pair = null; + ECPublicKey? identity_key = null; int32 signed_pre_key_id = -1; - ECPublicKey signed_pre_key = null; - SignedPreKeyRecord signed_pre_key_record = null; + ECPublicKey? signed_pre_key = null; + SignedPreKeyRecord? signed_pre_key_record = null; bool changed = false; if (node == null) { identity_key = store.identity_key_pair.public; @@ -336,7 +334,10 @@ public class StreamModule : XmppStreamModule { } else { Bundle bundle = new Bundle(node); foreach (Bundle.PreKey prekey in bundle.pre_keys) { - keys[prekey.key_id] = prekey.key; + ECPublicKey? key = prekey.key; + if (key != null) { + keys[prekey.key_id] = (!)key; + } } identity_key = bundle.identity_key; signed_pre_key_id = bundle.signed_pre_key_id;; @@ -345,16 +346,16 @@ public class StreamModule : XmppStreamModule { try { // Validate IdentityKey - if (store.identity_key_pair.public.compare(identity_key) != 0) { + if (identity_key == null || store.identity_key_pair.public.compare((!)identity_key) != 0) { changed = true; } - identity_key_pair = store.identity_key_pair; + IdentityKeyPair identity_key_pair = store.identity_key_pair; // Validate signedPreKeyRecord + ID - if (signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare(signed_pre_key) != 0) { + if (signed_pre_key == null || signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare((!)signed_pre_key) != 0) { signed_pre_key_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number - signed_pre_key_record = Plugin.context.generate_signed_pre_key(identity_key_pair, signed_pre_key_id); - store.store_signed_pre_key(signed_pre_key_record); + signed_pre_key_record = Plugin.get_context().generate_signed_pre_key(identity_key_pair, signed_pre_key_id); + store.store_signed_pre_key((!)signed_pre_key_record); changed = true; } else { signed_pre_key_record = store.load_signed_pre_key(signed_pre_key_id); @@ -373,7 +374,7 @@ public class StreamModule : XmppStreamModule { int new_keys = NUM_KEYS_TO_PUBLISH - pre_key_records.size; if (new_keys > 0) { int32 next_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number - Set new_records = Plugin.context.generate_pre_keys((uint)next_id, (uint)new_keys); + Set new_records = Plugin.get_context().generate_pre_keys((uint)next_id, (uint)new_keys); pre_key_records.add_all(new_records); foreach (PreKeyRecord record in new_records) { store.store_pre_key(record); @@ -382,7 +383,7 @@ public class StreamModule : XmppStreamModule { } if (changed) { - publish_bundles(stream, signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id); + publish_bundles(stream, (!)signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id); } } catch (Error e) { if (Plugin.DEBUG) print(@"Unexpected error while publishing bundle: $(e.message)\n"); diff --git a/plugins/signal-protocol/src/signal_helper.c b/plugins/signal-protocol/src/signal_helper.c index d13b9c95..5cbf2ce2 100644 --- a/plugins/signal-protocol/src/signal_helper.c +++ b/plugins/signal-protocol/src/signal_helper.c @@ -3,14 +3,30 @@ #include -signal_protocol_address* signal_protocol_address_new() { +signal_type_base* signal_type_ref_vapi(signal_type_base* instance) { + g_return_val_if_fail(instance != NULL, NULL); + signal_type_ref(instance); + return instance; +} + +signal_type_base* signal_type_unref_vapi(signal_type_base* instance) { + g_return_val_if_fail(instance != NULL, NULL); + signal_type_unref(instance); + return NULL; +} + +signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id) { + g_return_val_if_fail(name != NULL, NULL); signal_protocol_address* address = malloc(sizeof(signal_protocol_address)); - address->name = 0; - address->device_id = 0; + address->device_id = NULL; + address->name = NULL; + signal_protocol_address_set_name(address, name); + signal_protocol_address_set_device_id(address, device_id); return address; } void signal_protocol_address_free(signal_protocol_address* ptr) { + g_return_if_fail(ptr != NULL); if (ptr->name) { g_free((void*)ptr->name); } @@ -18,6 +34,8 @@ void signal_protocol_address_free(signal_protocol_address* ptr) { } void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name) { + g_return_if_fail(self != NULL); + g_return_if_fail(name != NULL); gchar* n = g_malloc(strlen(name)+1); memcpy(n, name, strlen(name)); n[strlen(name)] = 0; @@ -29,13 +47,25 @@ void signal_protocol_address_set_name(signal_protocol_address* self, const gchar } gchar* signal_protocol_address_get_name(signal_protocol_address* self) { - if (self->name == 0) return 0; + g_return_val_if_fail(self != NULL, NULL); + g_return_val_if_fail(self->name != NULL, 0); gchar* res = g_malloc(sizeof(char) * (self->name_len + 1)); memcpy(res, self->name, self->name_len); res[self->name_len] = 0; return res; } +int32_t signal_protocol_address_get_device_id(signal_protocol_address* self) { + g_return_val_if_fail(self != NULL, NULL); + return self->device_id; +} + +void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id) { + g_return_if_fail(self != NULL); + self->device_id = device_id; +} + + session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err) { session_pre_key* res; *err = session_pre_key_create(&res, pre_key_id, pair); diff --git a/plugins/signal-protocol/src/signal_helper.h b/plugins/signal-protocol/src/signal_helper.h index b4b05582..4319bce6 100644 --- a/plugins/signal-protocol/src/signal_helper.h +++ b/plugins/signal-protocol/src/signal_helper.h @@ -9,10 +9,14 @@ signal_type_base* signal_type_ref_vapi(signal_type_base* what); signal_type_base* signal_type_unref_vapi(signal_type_base* what); -signal_protocol_address* signal_protocol_address_new(); + +signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id); void signal_protocol_address_free(signal_protocol_address* ptr); void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name); gchar* signal_protocol_address_get_name(signal_protocol_address* self); +void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id); +int32_t signal_protocol_address_get_device_id(signal_protocol_address* self); + session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err); session_signed_pre_key* session_signed_pre_key_new(uint32_t id, uint64_t timestamp, ec_key_pair* pair, uint8_t* key, int key_len, int* err); diff --git a/plugins/signal-protocol/src/simple_ss.vala b/plugins/signal-protocol/src/simple_ss.vala index cc8e6b78..5213f736 100644 --- a/plugins/signal-protocol/src/simple_ss.vala +++ b/plugins/signal-protocol/src/simple_ss.vala @@ -7,10 +7,8 @@ public class SimpleSessionStore : SessionStore { private Map> session_map = new HashMap>(); public override uint8[]? load_session(Address address) throws Error { - string name = address.name; - if (name == null) return null; - if (session_map.has_key(name)) { - foreach (SessionStore.Session session in session_map[name]) { + if (session_map.has_key(address.name)) { + foreach (SessionStore.Session session in session_map[address.name]) { if (session.device_id == address.device_id) return session.record; } } diff --git a/plugins/signal-protocol/src/store.vala b/plugins/signal-protocol/src/store.vala index e0d74d0d..eab57e5b 100644 --- a/plugins/signal-protocol/src/store.vala +++ b/plugins/signal-protocol/src/store.vala @@ -142,9 +142,9 @@ public class Store : Object { return 0; } - static int ss_load_session_func(out Buffer buffer, Address address, void* user_data) { + static int ss_load_session_func(out Buffer? buffer, Address address, void* user_data) { Store store = (Store) user_data; - uint8[] res = null; + uint8[]? res = null; try { res = store.session_store.load_session(address); } catch (Error e) { @@ -155,12 +155,12 @@ public class Store : Object { buffer = null; return 0; } - buffer = new Buffer.from(res); + buffer = new Buffer.from((!)res); if (buffer == null) return ErrorCode.NOMEM; return 1; } - static int ss_get_sub_device_sessions_func(out IntList sessions, char[] name, void* user_data) { + static int ss_get_sub_device_sessions_func(out IntList? sessions, char[] name, void* user_data) { Store store = (Store) user_data; try { sessions = store.session_store.get_sub_device_sessions(carr_to_string(name)); @@ -206,9 +206,9 @@ public class Store : Object { return 0; } - static int pks_load_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) { + static int pks_load_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) { Store store = (Store) user_data; - uint8[] res = null; + uint8[]? res = null; try { res = store.pre_key_store.load_pre_key(pre_key_id); } catch (Error e) { @@ -219,7 +219,7 @@ public class Store : Object { record = new Buffer(0); return 0; } - record = new Buffer.from(res); + record = new Buffer.from((!)res); if (record == null) return ErrorCode.NOMEM; return 1; } @@ -251,9 +251,9 @@ public class Store : Object { return 0; } - static int spks_load_signed_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) { + static int spks_load_signed_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) { Store store = (Store) user_data; - uint8[] res = null; + uint8[]? res = null; try { res = store.signed_pre_key_store.load_signed_pre_key(pre_key_id); } catch (Error e) { @@ -264,7 +264,7 @@ public class Store : Object { record = new Buffer(0); return 0; } - record = new Buffer.from(res); + record = new Buffer.from((!)res); if (record == null) return ErrorCode.NOMEM; return 1; } diff --git a/plugins/signal-protocol/tests/session_builder.vala b/plugins/signal-protocol/tests/session_builder.vala index 4cc7a581..246cbd1c 100644 --- a/plugins/signal-protocol/tests/session_builder.vala +++ b/plugins/signal-protocol/tests/session_builder.vala @@ -18,12 +18,8 @@ class SessionBuilderTest : Gee.TestCase { public override void set_up() { try { global_context = new Context(); - alice_address = new Address(); - alice_address.name = "+14151111111"; - alice_address.device_id = 1; - bob_address = new Address(); - bob_address.name = "+14152222222"; - bob_address.device_id = 1; + alice_address = new Address("+14151111111", 1); + bob_address = new Address("+14152222222", 1); } catch (Error e) { fail_if_reached(@"Unexpected error: $(e.message)"); } diff --git a/plugins/signal-protocol/vapi/signal-protocol-public.vapi b/plugins/signal-protocol/vapi/signal-protocol-public.vapi index cd19549f..acdf36a3 100644 --- a/plugins/signal-protocol/vapi/signal-protocol-public.vapi +++ b/plugins/signal-protocol/vapi/signal-protocol-public.vapi @@ -51,7 +51,7 @@ namespace Signal { } [Compact] - [CCode (cname = "signal_type_base", ref_function="signal_type_ref", ref_function_void=true, unref_function="signal_type_unref", cheader_filename="signal_protocol_types.h,signal_helper.h")] + [CCode (cname = "signal_type_base", ref_function="signal_type_ref_vapi", unref_function="signal_type_unref_vapi", cheader_filename="signal_protocol_types.h,signal_helper.h")] public class TypeBase { } @@ -103,8 +103,8 @@ namespace Signal { [Compact] [CCode (cname = "session_pre_key_bundle", cprefix = "session_pre_key_bundle_", cheader_filename = "session_pre_key.h")] public class PreKeyBundle : TypeBase { - public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey pre_key_public, - uint32 signed_pre_key_id, ECPublicKey signed_pre_key_public, uint8[] signed_pre_key_signature, ECPublicKey identity_key); + public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey? pre_key_public, + uint32 signed_pre_key_id, ECPublicKey? signed_pre_key_public, uint8[]? signed_pre_key_signature, ECPublicKey? identity_key); public uint32 registration_id { get; } public int device_id { get; } public uint32 pre_key_id { get; } @@ -192,9 +192,8 @@ namespace Signal { [Compact] [CCode (cname = "signal_protocol_address", cprefix = "signal_protocol_address_", cheader_filename = "signal_protocol.h,signal_helper.h")] public class Address { - public Address(); - public int32 device_id; - + public Address(string name, int32 device_id); + public int32 device_id { get; set; } public string name { owned get; set; } } -- cgit v1.2.3-70-g09d2