[client] Improve STIX2 bundle splitter (#736)

richard-julien authored Sep 24, 2024
1 parent fe0a730 commit 3f6fbb4
Showing 12 changed files with 70,700 additions and 67 deletions.
25 changes: 13 additions & 12 deletions pycti/connector/opencti_connector_helper.py
@@ -1556,6 +1556,8 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list:
:type entities_types: list, optional
:param update: whether to update data in the database, defaults to False
:type update: bool, optional
:param bypass_split: used to prevent splitting of the bundle. This option has been removed since 6.3 and is no longer used.
:type bypass_split: bool, optional
:raises ValueError: if the bundle is empty
:return: list of bundles
:rtype: list
@@ -1564,11 +1566,11 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list:
entities_types = kwargs.get("entities_types", None)
update = kwargs.get("update", False)
event_version = kwargs.get("event_version", None)
bypass_split = kwargs.get("bypass_split", False)
bypass_validation = kwargs.get("bypass_validation", False)
entity_id = kwargs.get("entity_id", None)
file_name = kwargs.get("file_name", None)
bundle_send_to_queue = kwargs.get("send_to_queue", self.bundle_send_to_queue)
cleanup_inconsistent_bundle = kwargs.get("cleanup_inconsistent_bundle", False)
bundle_send_to_directory = kwargs.get(
"send_to_directory", self.bundle_send_to_directory
)
@@ -1690,17 +1692,16 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list:
final_write_file = os.path.join(bundle_send_to_directory_path, bundle_file)
os.rename(write_file, final_write_file)

if bypass_split:
bundles = [bundle]
expectations_number = len(json.loads(bundle)["objects"])
else:
stix2_splitter = OpenCTIStix2Splitter()
(
expectations_number,
bundles,
) = stix2_splitter.split_bundle_with_expectations(
bundle, True, event_version
)
stix2_splitter = OpenCTIStix2Splitter()
(
expectations_number,
bundles,
) = stix2_splitter.split_bundle_with_expectations(
bundle=bundle,
use_json=True,
event_version=event_version,
cleanup_inconsistent_bundle=cleanup_inconsistent_bundle,
)

if len(bundles) == 0:
self.metric.inc("error_count")
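For context, a minimal usage sketch of the updated helper API. It assumes an already-configured OpenCTIConnectorHelper instance named `helper` and an illustrative bundle; only the `cleanup_inconsistent_bundle` keyword comes from this change, and `bypass_split` is now ignored.

```python
import json
import uuid

# Hypothetical bundle: a report whose object_refs points at an object
# that is not part of the bundle (an inconsistent reference).
bundle_json = json.dumps(
    {
        "type": "bundle",
        "id": "bundle--" + str(uuid.uuid4()),
        "objects": [
            {
                "type": "report",
                "id": "report--" + str(uuid.uuid4()),
                "name": "Demo report",
                "published": "2024-09-24T00:00:00.000Z",
                "object_refs": ["indicator--00000000-0000-0000-0000-000000000000"],
            }
        ],
    }
)

# `helper` is assumed to be an OpenCTIConnectorHelper built from the connector
# configuration; the new keyword lets the splitter drop unresolvable references.
sent_bundles = helper.send_stix2_bundle(
    bundle_json,
    update=True,
    cleanup_inconsistent_bundle=True,
)
```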
10 changes: 4 additions & 6 deletions pycti/entities/opencti_kill_chain_phase.py
@@ -1,9 +1,8 @@
# coding: utf-8

import json
import uuid

from stix2.canonicalization.Canonicalize import canonicalize
from pycti.utils.opencti_stix2_identifier import kill_chain_phase_generate_id


class KillChainPhase:
@@ -25,10 +24,9 @@ def __init__(self, opencti):

@staticmethod
def generate_id(phase_name, kill_chain_name):
data = {"phase_name": phase_name, "kill_chain_name": kill_chain_name}
data = canonicalize(data, utf8=False)
id = str(uuid.uuid5(uuid.UUID("00abedb4-aa42-466c-9c01-fed23315a9b7"), data))
return "kill-chain-phase--" + id
return kill_chain_phase_generate_id(
phase_name=phase_name, kill_chain_name=kill_chain_name
)

"""
List Kill-Chain-Phase objects
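The refactor changes where the identifier is computed, not its value. As a quick sketch with hypothetical inputs, the entity method and the shared helper should keep producing the same deterministic UUIDv5-based ID:

```python
from pycti.entities.opencti_kill_chain_phase import KillChainPhase
from pycti.utils.opencti_stix2_identifier import kill_chain_phase_generate_id

# Both paths canonicalize the same {"phase_name", "kill_chain_name"} payload
# and hash it into the same UUID namespace, so the results are identical.
via_entity = KillChainPhase.generate_id(
    "reconnaissance", "lockheed-martin-cyber-kill-chain"
)
via_helper = kill_chain_phase_generate_id(
    phase_name="reconnaissance", kill_chain_name="lockheed-martin-cyber-kill-chain"
)
assert via_entity == via_helper  # e.g. "kill-chain-phase--<uuid5>"
```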
7 changes: 3 additions & 4 deletions pycti/utils/opencti_stix2.py
@@ -2619,10 +2619,9 @@ def import_bundle(
else None
)
stix2_splitter = OpenCTIStix2Splitter()
try:
bundles = stix2_splitter.split_bundle(stix_bundle, False, event_version)
except RecursionError:
bundles = [stix_bundle]
_, bundles = stix2_splitter.split_bundle_with_expectations(
stix_bundle, False, event_version
)
# Import every element in a specific order
imported_elements = []
for bundle in bundles:
22 changes: 22 additions & 0 deletions pycti/utils/opencti_stix2_identifier.py
@@ -0,0 +1,22 @@
import uuid

from stix2.canonicalization.Canonicalize import canonicalize


def external_reference_generate_id(url=None, source_name=None, external_id=None):
if url is not None:
data = {"url": url}
elif source_name is not None and external_id is not None:
data = {"source_name": source_name, "external_id": external_id}
else:
return None
data = canonicalize(data, utf8=False)
id = str(uuid.uuid5(uuid.UUID("00abedb4-aa42-466c-9c01-fed23315a9b7"), data))
return "external-reference--" + id


def kill_chain_phase_generate_id(phase_name, kill_chain_name):
data = {"phase_name": phase_name, "kill_chain_name": kill_chain_name}
data = canonicalize(data, utf8=False)
id = str(uuid.uuid5(uuid.UUID("00abedb4-aa42-466c-9c01-fed23315a9b7"), data))
return "kill-chain-phase--" + id
187 changes: 153 additions & 34 deletions pycti/utils/opencti_stix2_splitter.py
@@ -1,68 +1,187 @@
import json
import re
import uuid
from typing import Tuple

from typing_extensions import deprecated

MITRE_X_CAPEC = (
"x_capec_*" # https://github.com/mitre-attack/attack-stix-data/issues/34
from pycti.utils.opencti_stix2_identifier import (
external_reference_generate_id,
kill_chain_phase_generate_id,
)
unsupported_ref_patterns = [MITRE_X_CAPEC]
from pycti.utils.opencti_stix2_utils import (
STIX_CYBER_OBSERVABLE_MAPPING,
SUPPORTED_STIX_ENTITY_OBJECTS,
)

supported_types = (
SUPPORTED_STIX_ENTITY_OBJECTS # entities
+ list(STIX_CYBER_OBSERVABLE_MAPPING.keys()) # observables
+ ["relationship", "sighting"] # relationships
)


def is_id_supported(key):
id_type = key.split("--")[0]
return id_type in supported_types


class OpenCTIStix2Splitter:
def __init__(self):
self.cache_index = {}
self.cache_refs = {}
self.elements = []
self.unsupported_patterns = list(
map(lambda pattern: re.compile(pattern), unsupported_ref_patterns)
)

def is_ref_key_supported(self, key):
for pattern in self.unsupported_patterns:
if pattern.match(key):
return False
return True

def enlist_element(self, item_id, raw_data):
def enlist_element(
self, item_id, raw_data, cleanup_inconsistent_bundle, parent_acc
):
nb_deps = 1
if item_id not in raw_data:
return 0

existing_item = self.cache_index.get(item_id)
if existing_item is not None:
return existing_item["nb_deps"]
# Recursive enlist for every refs

item = raw_data[item_id]
if self.cache_refs.get(item_id) is None:
self.cache_refs[item_id] = []
for key in list(item.keys()):
value = item[key]
if key.endswith("_refs") and self.is_ref_key_supported(key):
# Recursive enlist for every refs
if key.endswith("_refs"):
to_keep = []
for element_ref in item[key]:
if element_ref != item_id:
nb_deps += self.enlist_element(element_ref, raw_data)
to_keep.append(element_ref)
# We need to check if this ref is not already a reference
is_missing_ref = raw_data.get(element_ref) is None
must_be_cleaned = is_missing_ref and cleanup_inconsistent_bundle
not_dependency_ref = (
self.cache_refs.get(element_ref) is None
or item_id not in self.cache_refs[element_ref]
)
# Prevent any self reference
if (
is_id_supported(element_ref)
and not must_be_cleaned
and element_ref not in parent_acc
and element_ref != item_id
and not_dependency_ref
):
self.cache_refs[item_id].append(element_ref)
nb_deps += self.enlist_element(
element_ref,
raw_data,
cleanup_inconsistent_bundle,
parent_acc + [element_ref],
)
if element_ref not in to_keep:
to_keep.append(element_ref)
item[key] = to_keep
elif key.endswith("_ref") and self.is_ref_key_supported(key):
if item[key] == item_id:
item[key] = None
elif key.endswith("_ref"):
is_missing_ref = raw_data.get(value) is None
must_be_cleaned = is_missing_ref and cleanup_inconsistent_bundle
not_dependency_ref = (
self.cache_refs.get(value) is None
or item_id not in self.cache_refs[value]
)
# Prevent any self reference
if (
value is not None
and not must_be_cleaned
and value not in parent_acc
and is_id_supported(value)
and value != item_id
and not_dependency_ref
):
self.cache_refs[item_id].append(value)
nb_deps += self.enlist_element(
value,
raw_data,
cleanup_inconsistent_bundle,
parent_acc + [value],
)
else:
# Need to handle the special case of recursive ref for created by ref
is_created_by_ref = key == "created_by_ref"
if is_created_by_ref:
is_marking = item["id"].startswith("marking-definition--")
if is_marking is False:
nb_deps += self.enlist_element(value, raw_data)
else:
nb_deps += self.enlist_element(value, raw_data)
item[key] = None
# Case for embedded elements (deduplicating and cleanup)
elif key == "external_references":
# specific case of splitting external references
# reference_ids = []
deduplicated_references = []
deduplicated_references_cache = {}
references = item[key]
for reference in references:
reference_id = external_reference_generate_id(
url=reference.get("url"),
source_name=reference.get("source_name"),
external_id=reference.get("external_id"),
)
if (
reference_id is not None
and deduplicated_references_cache.get(reference_id) is None
):
deduplicated_references_cache[reference_id] = reference_id
deduplicated_references.append(reference)
# - Needed for a future move of splitting the elements
# reference["id"] = reference_id
# reference["type"] = "External-Reference"
# raw_data[reference_id] = reference
# if reference_id not in reference_ids:
# reference_ids.append(reference_id)
# nb_deps += self.enlist_element(reference_id, raw_data)
item[key] = deduplicated_references
elif key == "kill_chain_phases":
# specific case of splitting kill_chain phases
# kill_chain_ids = []
deduplicated_kill_chain = []
deduplicated_kill_chain_cache = {}
kill_chains = item[key]
for kill_chain in kill_chains:
kill_chain_id = kill_chain_phase_generate_id(
kill_chain_name=kill_chain.get("kill_chain_name"),
phase_name=kill_chain.get("phase_name"),
)
if (
kill_chain_id is not None
and deduplicated_kill_chain_cache.get(kill_chain_id) is None
):
deduplicated_kill_chain_cache[kill_chain_id] = kill_chain_id
deduplicated_kill_chain.append(kill_chain)
# - Needed for a future move of splitting the elements
# kill_chain["id"] = kill_chain_id
# kill_chain["type"] = "Kill-Chain-Phase"
# raw_data[kill_chain_id] = kill_chain
# if kill_chain_id not in kill_chain_ids:
# kill_chain_ids.append(kill_chain_id)
# nb_deps += self.enlist_element(kill_chain_id, raw_data)
item[key] = deduplicated_kill_chain

# Get the final dep counting and add in cache
item["nb_deps"] = nb_deps
self.elements.append(item)
self.cache_index[item_id] = item # Put in cache
# Put in cache
if self.cache_index.get(item_id) is None:
# enlist only if compatible
if item["type"] == "relationship":
is_compatible = (
item["source_ref"] is not None and item["target_ref"] is not None
)
elif item["type"] == "sighting":
is_compatible = (
item["sighting_of_ref"] is not None
and len(item["where_sighted_refs"]) > 0
)
else:
is_compatible = is_id_supported(item_id)
if is_compatible:
self.elements.append(item)
self.cache_index[item_id] = item

return nb_deps

def split_bundle_with_expectations(
self, bundle, use_json=True, event_version=None
self,
bundle,
use_json=True,
event_version=None,
cleanup_inconsistent_bundle=False,
) -> Tuple[int, list]:
"""splits a valid stix2 bundle into a list of bundles"""
if use_json:
@@ -84,7 +203,7 @@ def split_bundle_with_expectations(
for item in bundle_data["objects"]:
raw_data[item["id"]] = item
for item in bundle_data["objects"]:
self.enlist_element(item["id"], raw_data)
self.enlist_element(item["id"], raw_data, cleanup_inconsistent_bundle, [])

# Build the bundles
bundles = []
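To illustrate the new option end to end, a minimal sketch of driving the splitter directly (hypothetical bundle content; the import path follows the file above). With cleanup_inconsistent_bundle=True, references that cannot be resolved inside the bundle are cleaned up rather than kept as dangling dependencies.

```python
import json
import uuid

from pycti.utils.opencti_stix2_splitter import OpenCTIStix2Splitter

identity_id = "identity--" + str(uuid.uuid4())
missing_id = "malware--" + str(uuid.uuid4())  # intentionally absent from the bundle

bundle = {
    "type": "bundle",
    "id": "bundle--" + str(uuid.uuid4()),
    "objects": [
        {
            "type": "identity",
            "id": identity_id,
            "name": "ACME",
            "identity_class": "organization",
        },
        {
            "type": "relationship",
            "id": "relationship--" + str(uuid.uuid4()),
            "relationship_type": "related-to",
            "source_ref": identity_id,
            "target_ref": missing_id,  # dangling reference
        },
    ],
}

splitter = OpenCTIStix2Splitter()
expectations, bundles = splitter.split_bundle_with_expectations(
    bundle=json.dumps(bundle),
    use_json=True,
    event_version=None,
    cleanup_inconsistent_bundle=True,
)
# How many elements the worker should expect, and the split bundles produced.
print(expectations, len(bundles))
```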