mirror of
https://github.com/aljazceru/recon-pipeline.git
synced 2025-12-19 15:24:28 +01:00
210 lines
8.5 KiB
Python
210 lines
8.5 KiB
Python
import sqlite3
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse
|
|
|
|
from cmd2 import ansi
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy import exc, or_, create_engine
|
|
from sqlalchemy.sql.expression import ClauseElement
|
|
|
|
from .base_model import Base
|
|
from .port_model import Port
|
|
from .nse_model import NSEResult
|
|
from .target_model import Target
|
|
from .nmap_model import NmapResult
|
|
from .endpoint_model import Endpoint
|
|
from .ip_address_model import IPAddress
|
|
from .technology_model import Technology
|
|
from .searchsploit_model import SearchsploitResult
|
|
from ..recon.helpers import get_ip_address_version, is_ip_address
|
|
|
|
|
|
class DBManager:
|
|
""" Class that encapsulates database transactions and queries """
|
|
|
|
def __init__(self, db_location):
|
|
self.location = Path(db_location).expanduser().resolve()
|
|
self.connection_string = f"sqlite:///{self.location}"
|
|
engine = create_engine(self.connection_string)
|
|
Base.metadata.create_all(engine) # noqa: F405
|
|
session_factory = sessionmaker(bind=engine)
|
|
self.session = session_factory()
|
|
|
|
def get_or_create(self, model, **kwargs):
|
|
""" Simple helper to either get an existing record if it exists otherwise create and return a new instance """
|
|
instance = self.session.query(model).filter_by(**kwargs).first()
|
|
if instance:
|
|
return instance
|
|
else:
|
|
params = dict((k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement))
|
|
instance = model(**params)
|
|
return instance
|
|
|
|
def add(self, item):
|
|
""" Simple helper to add a record to the database """
|
|
try:
|
|
self.session.add(item)
|
|
self.session.commit()
|
|
except (sqlite3.IntegrityError, exc.IntegrityError):
|
|
print(ansi.style("[-] unique key constraint handled, moving on...", fg="bright_white"))
|
|
self.session.rollback()
|
|
|
|
def get_or_create_target_by_ip_or_hostname(self, ip_or_host):
|
|
""" Simple helper to query a Target record by either hostname or ip address, whichever works """
|
|
# get existing instance
|
|
instance = (
|
|
self.session.query(Target)
|
|
.filter(
|
|
or_(
|
|
Target.ip_addresses.any(
|
|
or_(IPAddress.ipv4_address.in_([ip_or_host]), IPAddress.ipv6_address.in_([ip_or_host]))
|
|
),
|
|
Target.hostname == ip_or_host,
|
|
)
|
|
)
|
|
.first()
|
|
)
|
|
if instance:
|
|
return instance
|
|
else:
|
|
# create new entry
|
|
tgt = self.get_or_create(Target)
|
|
|
|
if get_ip_address_version(ip_or_host) == "4":
|
|
tgt.ip_addresses.append(IPAddress(ipv4_address=ip_or_host))
|
|
elif get_ip_address_version(ip_or_host) == "6":
|
|
tgt.ip_addresses.append(IPAddress(ipv6_address=ip_or_host))
|
|
else:
|
|
# we've already determined it's not an IP, only other possibility is a hostname
|
|
tgt.hostname = ip_or_host
|
|
|
|
return tgt
|
|
|
|
def get_all_hostnames(self) -> list:
|
|
""" Simple helper to return all hostnames from Target records """
|
|
return [x[0] for x in self.session.query(Target.hostname).filter(Target.hostname != None)] # noqa: E711
|
|
|
|
def get_all_ipv4_addresses(self) -> list:
|
|
""" Simple helper to return all ipv4 addresses from Target records """
|
|
query = self.session.query(IPAddress.ipv4_address).filter(IPAddress.ipv4_address != None) # noqa: E711
|
|
return [x[0] for x in query]
|
|
|
|
def get_all_ipv6_addresses(self) -> list:
|
|
""" Simple helper to return all ipv6 addresses from Target records """
|
|
query = self.session.query(IPAddress.ipv6_address).filter(IPAddress.ipv6_address != None) # noqa: E711
|
|
return [x[0] for x in query]
|
|
|
|
def close(self):
|
|
""" Simple helper to close the database session """
|
|
self.session.close()
|
|
|
|
def get_all_targets(self):
|
|
""" Simple helper to return all ipv4/6 and hostnames produced by running amass """
|
|
return self.get_all_hostnames() + self.get_all_ipv4_addresses() + self.get_all_ipv6_addresses()
|
|
|
|
def get_all_endpoints(self):
|
|
""" Simple helper that returns all Endpoints from the database """
|
|
return self.session.query(Endpoint).all()
|
|
|
|
def get_all_port_numbers(self):
|
|
""" Simple helper that returns all Port.port_numbers from the database """
|
|
return set(str(x[0]) for x in self.session.query(Port.port_number).all())
|
|
|
|
def get_endpoint_by_status_code(self, code):
|
|
""" Simple helper that returns all Endpoints filtered by status code """
|
|
return self.session.query(Endpoint).filter(Endpoint.status_code == code).all()
|
|
|
|
def get_endpoints_by_ip_or_hostname(self, ip_or_host):
|
|
""" Simple helper that returns all Endpoints filtered by ip or hostname """
|
|
endpoints = list()
|
|
|
|
tmp_endpoints = self.session.query(Endpoint).filter(Endpoint.url.contains(ip_or_host)).all()
|
|
|
|
for ep in tmp_endpoints:
|
|
parsed_url = urlparse(ep.url)
|
|
if parsed_url.hostname == ip_or_host:
|
|
endpoints.append(ep)
|
|
|
|
return endpoints
|
|
|
|
def get_nmap_scans_by_ip_or_hostname(self, ip_or_host):
|
|
""" Simple helper that returns all Endpoints filtered by ip or hostname """
|
|
scans = list()
|
|
|
|
for result in self.session.query(NmapResult).filter(NmapResult.commandline.contains(ip_or_host)).all():
|
|
if result.commandline.split()[-1] == ip_or_host:
|
|
scans.append(result)
|
|
|
|
return scans
|
|
|
|
def get_status_codes(self):
|
|
""" Simple helper that returns all status codes found during scanning """
|
|
return set(str(x[0]) for x in self.session.query(Endpoint.status_code).all() if x[0] is not None)
|
|
|
|
def get_and_filter(self, model, defaults=None, **kwargs):
|
|
""" Simple helper to either get an existing record if it exists otherwise create and return a new instance """
|
|
return self.session.query(model).filter_by(**kwargs).all()
|
|
|
|
def get_all_nse_script_types(self):
|
|
""" Simple helper that returns all NSE Script types from the database """
|
|
return set(str(x[0]) for x in self.session.query(NSEResult.script_id).all())
|
|
|
|
def get_all_nmap_reported_products(self):
|
|
""" Simple helper that returns all products reported by nmap """
|
|
return set(str(x[0]) for x in self.session.query(NmapResult.product).all())
|
|
|
|
def get_all_exploit_types(self):
|
|
""" Simple helper that returns all exploit types reported by searchsploit """
|
|
return set(str(x[0]) for x in self.session.query(SearchsploitResult.type).all())
|
|
|
|
def add_ipv4_or_v6_address_to_target(self, tgt, ipaddr):
|
|
""" Simple helper that adds an appropriate IPAddress to the given target """
|
|
if not is_ip_address(ipaddr):
|
|
return
|
|
|
|
if get_ip_address_version(ipaddr) == "4":
|
|
ip_address = self.get_or_create(IPAddress, ipv4_address=ipaddr)
|
|
else:
|
|
ip_address = self.get_or_create(IPAddress, ipv6_address=ipaddr)
|
|
|
|
tgt.ip_addresses.append(ip_address)
|
|
|
|
return tgt
|
|
|
|
def get_all_web_targets(self):
|
|
""" Simple helper that returns all Targets tagged as having an open web port """
|
|
# return set(str(x[0]) for x in self.session.query(Target).all())
|
|
web_targets = list()
|
|
targets = self.get_and_filter(Target, is_web=True)
|
|
|
|
for target in targets:
|
|
if target.hostname:
|
|
web_targets.append(target.hostname)
|
|
for ipaddr in target.ip_addresses:
|
|
if ipaddr.ipv4_address:
|
|
web_targets.append(ipaddr.ipv4_address)
|
|
if ipaddr.ipv6_address:
|
|
web_targets.append(ipaddr.ipv6_address)
|
|
|
|
return web_targets
|
|
|
|
def get_ports_by_ip_or_host_and_protocol(self, ip_or_host, protocol):
|
|
""" Simple helper that returns all ports based on the given protocol and host """
|
|
tgt = self.get_or_create_target_by_ip_or_hostname(ip_or_host)
|
|
ports = list()
|
|
|
|
for port in tgt.open_ports:
|
|
if port.protocol == protocol:
|
|
ports.append(str(port.port_number))
|
|
|
|
return ports
|
|
|
|
def get_all_searchsploit_results(self):
|
|
return self.get_and_filter(SearchsploitResult)
|
|
|
|
def get_all_web_technology_types(self):
|
|
return set(str(x[0]) for x in self.session.query(Technology.type).all())
|
|
|
|
def get_all_web_technology_products(self):
|
|
return set(str(x[0]) for x in self.session.query(Technology.text).all())
|