Merge pull request #121 from whiteroses/fix-out-of-memory-issue-119

Improve memory usage in the generation of tasks (to fix issue #119)
Authored by Sajeeb Lohani on 2020-10-03 23:20:58 +10:00; committed by GitHub.
4 changed files with 264 additions and 205 deletions


@@ -7,13 +7,18 @@ from Interlace.lib.core.output import OutputHelper, Level
 from Interlace.lib.threader import Pool


-def build_queue(arguments, output, repeat):
-    task_list = InputHelper.process_commands(arguments)
-
-    for task in task_list:
-        output.terminal(Level.THREAD, task.name(), "Added to Queue")
-
-    print('Generated {} commands in total'.format(len(task_list)))
-    print('Repeat set to {}'.format(repeat))
-    return task_list * repeat
+def task_queue_generator_func(arguments, output, repeat):
+    tasks_data = InputHelper.process_data_for_tasks_iterator(arguments)
+    tasks_count = tasks_data["tasks_count"]
+    yield tasks_count
+
+    tasks_generator_func = InputHelper.make_tasks_generator_func(tasks_data)
+    for i in range(repeat):
+        tasks_iterator = tasks_generator_func()
+        for task in tasks_iterator:
+            output.terminal(Level.THREAD, task.name(), "Added to Queue")
+            yield task
+
+    print('Generated {} commands in total'.format(tasks_count))
+    print('Repeat set to {}'.format(repeat))


 def main():
@@ -28,8 +33,13 @@ def main():
     else:
         repeat = 1

-    pool = Pool(arguments.threads, build_queue(arguments, output, repeat), arguments.timeout, output, arguments.sober)
+    pool = Pool(
+        arguments.threads,
+        task_queue_generator_func(arguments, output, repeat),
+        arguments.timeout,
+        output,
+        arguments.sober,
+    )

     pool.run()
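The queue handed to Pool is now a generator rather than a fully built list, so the complete set of commands never has to sit in memory at once. Below is a minimal sketch of the protocol the rewritten task_queue_generator_func follows, where the first yielded value is the total task count and everything after it is a task; the helper name and the sample commands are illustrative only, not part of this patch.

    def make_counted_queue(tasks, repeat):
        # Hypothetical stand-in for task_queue_generator_func: the first yield
        # is the total number of tasks, so the consumer can size a progress bar
        # without calling len() on a generator.
        yield len(tasks) * repeat
        for _ in range(repeat):
            for task in tasks:
                yield task

    queue = make_counted_queue(["echo a", "echo b"], repeat=2)
    total = next(queue)      # 4
    print(list(queue))       # ['echo a', 'echo b', 'echo a', 'echo b']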


@@ -1 +1 @@
-__version__ = '1.8.2'
+__version__ = '1.9.0'


@@ -1,11 +1,16 @@
+import functools
+import itertools
 import os.path
 import sys
 from io import TextIOWrapper
 from argparse import ArgumentParser
-from math import ceil
-from random import sample, choice
+from random import choice

-from netaddr import IPNetwork, IPRange, IPGlob
+from netaddr import (
+    IPRange,
+    IPSet,
+    glob_to_iprange,
+)

 from Interlace.lib.threader import Task
@@ -45,42 +50,6 @@ class InputHelper(object):
         return files

-    @staticmethod
-    def _get_ips_from_range(ip_range):
-        ips = set()
-        ip_range = ip_range.split("-")
-
-        # parsing the above structure into an array and then making into an IP address with the end value
-        end_ip = ".".join(ip_range[0].split(".")[0:-1]) + "." + ip_range[1]
-
-        # creating an IPRange object to get all IPs in between
-        range_obj = IPRange(ip_range[0], end_ip)
-        for ip in range_obj:
-            ips.add(str(ip))
-        return ips
-
-    @staticmethod
-    def _get_ips_from_glob(glob_ips):
-        ip_glob = IPGlob(glob_ips)
-        ips = set()
-        for ip in ip_glob:
-            ips.add(str(ip))
-        return ips
-
-    @staticmethod
-    def _get_cidr_to_ips(cidr_range):
-        ips = set()
-        for ip in IPNetwork(cidr_range):
-            ips.add(str(ip))
-        return ips
-
     @staticmethod
     def _process_port(port_type):
         if "," in port_type:
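The three helpers removed above expanded every CIDR block, dash range, and glob into individual IP strings and kept them all in a Python set, which is what exhausted memory on large ranges in issue #119. The replacement code further down keeps such targets in a netaddr IPSet, which stores contiguous ranges compactly and can report how many addresses it covers without materialising them. A small sketch of the difference, using an example /16 network:

    from netaddr import IPNetwork, IPSet

    cidr = "10.0.0.0/16"

    # Old approach: roughly 65k separate strings held alive in a set.
    expanded = set(str(ip) for ip in IPNetwork(cidr))
    print(len(expanded))    # 65536

    # New approach: one compact range object; the size is computed on demand.
    compact = IPSet([cidr])
    print(compact.size)     # 65536, without 65536 string objects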
@@ -146,178 +115,256 @@ class InputHelper(object):
         return task_block

     @staticmethod
-    def _pre_process_hosts(host_ranges, destination_set, arguments):
-        for host in host_ranges:
-            host = host.replace(" ", "").replace("\n", "")
-            # check if it is a domain name
-            if len(host.split(".")[0]) == 0:
-                destination_set.add(host)
-                continue
-
-            if host.split(".")[0][0].isalpha() or host.split(".")[-1][-1].isalpha():
-                destination_set.add(host)
-                continue
-            for ips in host.split(","):
-                # checking for CIDR
-                if not arguments.nocidr and "/" in ips:
-                    destination_set.update(InputHelper._get_cidr_to_ips(ips))
-                # checking for IPs in a range
-                elif "-" in ips:
-                    destination_set.update(InputHelper._get_ips_from_range(ips))
-                # checking for glob ranges
-                elif "*" in ips:
-                    destination_set.update(InputHelper._get_ips_from_glob(ips))
-                else:
-                    destination_set.add(ips)
+    def _replace_target_variables_in_commands(tasks, str_targets, ipset_targets):
+        TARGET_VAR = "_target_"
+        HOST_VAR = "_host_"
+        CLEANTARGET_VAR = "_cleantarget_"
+        for task in tasks:
+            command = task.name()
+            if TARGET_VAR in command or HOST_VAR in command:
+                for dirty_target in itertools.chain(str_targets, ipset_targets):
+                    yielded_task = task.clone()
+                    dirty_target = str(dirty_target)
+                    yielded_task.replace(TARGET_VAR, dirty_target)
+                    yielded_task.replace(HOST_VAR, dirty_target)
+                    yielded_task.replace(
+                        CLEANTARGET_VAR,
+                        dirty_target.replace("http://", "").replace(
+                            "https://", "").rstrip("/").replace("/", "-"),
+                    )
+                    yield yielded_task
+            else:
+                yield task
     @staticmethod
-    def _process_clean_targets(commands, dirty_targets):
-        def add_task(t, item_list, my_command_set):
-            if t not in my_command_set:
-                my_command_set.add(t)
-                item_list.append(t)
-
-        variable = '_cleantarget_'
-        tasks = []
-        temp = set()  # this helps avoid command duplication and re/deconstructing of temporary set
-        # changed order to ensure different combinations of commands aren't created
-        for dirty_target in dirty_targets:
-            for command in commands:
-                new_task = command.clone()
-                if command.name().find(variable) != -1:
-                    new_task.replace("_target_", dirty_target)
-                    # replace all https:// or https:// with nothing
-                    dirty_target = dirty_target.replace('http://', '')
-                    dirty_target = dirty_target.replace('https://', '')
-                    # chop off all trailing '/', if any.
-                    while dirty_target.endswith('/'):
-                        dirty_target = dirty_target.strip('/')
-                    # replace all remaining '/' with '-' and that's enough cleanup for the day
-                    clean_target = dirty_target.replace('/', '-')
-                    new_task.replace(variable, clean_target)
-                    add_task(new_task, tasks, temp)
-                else:
-                    new_task.replace("_target_", dirty_target)
-                    add_task(new_task, tasks, temp)
-        return tasks
-
-    @staticmethod
-    def _replace_variable_with_commands(commands, variable, replacements):
-        def add_task(t, item_list, my_set):
-            if t not in my_set:
-                my_set.add(t)
-                item_list.append(t)
-
-        tasks = []
-        temp_set = set()  # to avoid duplicates
-        for command in commands:
-            for replacement in replacements:
-                if command.name().find(variable) != -1:
-                    new_task = command.clone()
-                    new_task.replace(variable, str(replacement))
-                    add_task(new_task, tasks, temp_set)
-                else:
-                    add_task(command, tasks, temp_set)
-        return tasks
+    def _replace_variable_in_commands(tasks_generator_func, variable, replacements):
+        for task in tasks_generator_func():
+            if variable in task.name():
+                for replacement in replacements:
+                    yielded_task = task.clone()
+                    yielded_task.replace(variable, str(replacement))
+                    yield yielded_task
+            else:
+                yield task
     @staticmethod
-    def _replace_variable_array(commands, variable, replacement):
-        if variable not in sample(commands, 1)[0]:
-            return
-
-        for counter, command in enumerate(commands):
-            command.replace(variable, str(replacement[counter]))
+    def _replace_variable_array(
+            tasks_generator_func, variable, replacements_iterator
+    ):
+        for task in tasks_generator_func():
+            task.replace(variable, str(next(replacements_iterator)))
+            yield task
     @staticmethod
-    def process_commands(arguments):
-        commands = list()
-        ranges = set()
-        targets = set()
-        exclusions_ranges = set()
-        exclusions = set()
+    def _process_targets(arguments):
+        def pre_process_target_spec(target_spec):
+            target_spec = "".join(
+                filter(lambda char: char not in (" ", "\n"), target_spec)
+            )
+            return target_spec.split(",")
+            # If ","s not in target_spec, this returns [target_spec], so this
+            # static method always returns a list
+
+        if arguments.target:
+            target_specs = pre_process_target_spec(arguments.target)
+        else:
+            target_specs_file = arguments.target_list
+            if not isinstance(target_specs_file, TextIOWrapper):
+                if not sys.stdin.isatty():
+                    target_specs_file = sys.stdin
+            target_specs = (
+                target_spec.strip() for target_spec in target_specs_file
+            )
+            target_specs = (
+                pre_process_target_spec(target_spec) for target_spec in
+                target_specs if target_spec
+            )
+            target_specs = itertools.chain(*target_specs)
+
+        def parse_and_group_target_specs(target_specs, nocidr):
+            str_targets = set()
+            ipset_targets = IPSet()
+            for target_spec in target_specs:
+                if (
+                    target_spec.startswith(".") or
+                    (
+                        (target_spec[0].isalpha() or target_spec[-1].isalpha())
+                        and "." in target_spec
+                    ) or
+                    (nocidr and "/" in target_spec)
+                ):
+                    str_targets.add(target_spec)
+                else:
+                    if "-" in target_spec:
+                        start_ip, post_dash_segment = target_spec.split("-")
+                        end_ip = start_ip.rsplit(".", maxsplit=1)[0] + "." + \
+                            post_dash_segment
+                        target_spec = IPRange(start_ip, end_ip)
+                    elif "*" in target_spec:
+                        target_spec = glob_to_iprange(target_spec)
+                    else:  # str IP addresses and str CIDR notations
+                        target_spec = (target_spec,)
+                    ipset_targets.update(IPSet(target_spec))
+            return (str_targets, ipset_targets)
+
+        str_targets, ipset_targets = parse_and_group_target_specs(
+            target_specs=target_specs,
+            nocidr=arguments.nocidr,
+        )
+
+        if arguments.exclusions or arguments.exclusions_list:
+            if arguments.exclusions:
+                exclusion_specs = pre_process_target_spec(arguments.exclusions)
+            elif arguments.exclusions_list:
+                exclusion_specs = (
+                    exclusion_spec.strip() for exclusion_spec in
+                    arguments.exclusions_list
+                )
+                exclusion_specs = (
+                    pre_process_target_spec(exclusion_spec) for exclusion_spec
+                    in exclusion_specs if exclusion_spec
+                )
+                exclusion_specs = itertools.chain(*exclusion_specs)
+            str_exclusions, ipset_exclusions = parse_and_group_target_specs(
+                target_specs=exclusion_specs,
+                nocidr=arguments.nocidr,
+            )
+            str_targets -= str_exclusions
+            ipset_targets -= ipset_exclusions
+
+        return (str_targets, ipset_targets)
+    @staticmethod
+    def process_data_for_tasks_iterator(arguments):
         # removing the trailing slash if any
         if arguments.output and arguments.output[-1] == "/":
             arguments.output = arguments.output[:-1]

-        if arguments.port:
-            ports = InputHelper._process_port(arguments.port)
-        if arguments.realport:
-            real_ports = InputHelper._process_port(arguments.realport)
-
-        # process targets first
-        if arguments.target:
-            ranges.add(arguments.target)
-        else:
-            target_file = arguments.target_list
-            if not isinstance(target_file, TextIOWrapper):
-                if not sys.stdin.isatty():
-                    target_file = sys.stdin
-            ranges.update([target.strip() for target in target_file if target.strip()])
-
-        # process exclusions first
-        if arguments.exclusions:
-            exclusions_ranges.add(arguments.exclusions)
-        else:
-            if arguments.exclusions_list:
-                for exclusion in arguments.exclusions_list:
-                    exclusion = exclusion.strip()
-                    if exclusion:
-                        exclusions.add(exclusion)
-
-        # removing elements that may have spaces (helpful for easily processing comma notation)
-        InputHelper._pre_process_hosts(ranges, targets, arguments)
-        InputHelper._pre_process_hosts(exclusions_ranges, exclusions, arguments)
-
-        # difference operation
-        targets -= exclusions
-
-        if len(targets) == 0:
+        ports = InputHelper._process_port(arguments.port) if arguments.port \
+            else None
+        real_ports = InputHelper._process_port(arguments.realport) if \
+            arguments.realport else None
+
+        str_targets, ipset_targets = InputHelper._process_targets(
+            arguments=arguments,
+        )
+        targets_count = len(str_targets) + ipset_targets.size
+
+        if not targets_count:
             raise Exception("No target provided, or empty target list")

         if arguments.random:
             files = InputHelper._get_files_from_directory(arguments.random)
             random_file = choice(files)
+        else:
+            random_file = None

-        if arguments.command:
-            commands.append(Task(arguments.command.rstrip('\n')))
-        else:
-            commands = InputHelper._pre_process_commands(arguments.command_list)
-
-        # commands = InputHelper._replace_variable_with_commands(commands, "_target_", targets)
-        commands = InputHelper._process_clean_targets(commands, targets)
-        commands = InputHelper._replace_variable_with_commands(commands, "_host_", targets)
-
-        if arguments.port:
-            commands = InputHelper._replace_variable_with_commands(commands, "_port_", ports)
-
-        if arguments.realport:
-            commands = InputHelper._replace_variable_with_commands(commands, "_realport_", real_ports)
-
-        if arguments.random:
-            commands = InputHelper._replace_variable_with_commands(commands, "_random_", [random_file])
-
-        if arguments.output:
-            commands = InputHelper._replace_variable_with_commands(commands, "_output_", [arguments.output])
+        tasks = list()
+        if arguments.command:
+            tasks.append(Task(arguments.command.rstrip('\n')))
+        else:
+            tasks = InputHelper._pre_process_commands(arguments.command_list)

         if arguments.proto:
-            if "," in arguments.proto:
-                protocols = arguments.proto.split(",")
-            else:
-                protocols = arguments.proto
-            commands = InputHelper._replace_variable_with_commands(commands, "_proto_", protocols)
-
-        # process proxies
-        if arguments.proxy_list:
-            proxy_list = [proxy for proxy in arguments.proxy_list if proxy.strip()]
-            if len(proxy_list) < len(commands):
-                proxy_list = ceil(len(commands) / len(proxy_list)) * proxy_list
-
-            InputHelper._replace_variable_array(commands, "_proxy_", proxy_list)
-
-        return commands
+            protocols = arguments.proto.split(",")
+            # if "," not in arguments.proto, [arguments.proto] is returned by
+            # .split()
+        else:
+            protocols = None
+
+        # Calculate the tasks count, as we will not have access to the len() of
+        # the tasks iterator
+        tasks_count = len(tasks) * targets_count
+        if ports:
+            tasks_count *= len(ports)
+        if real_ports:
+            tasks_count *= len(real_ports)
+        if protocols:
+            tasks_count *= len(protocols)
+
+        return {
+            "tasks": tasks,
+            "str_targets": str_targets,
+            "ipset_targets": ipset_targets,
+            "ports": ports,
+            "real_ports": real_ports,
+            "random_file": random_file,
+            "output": arguments.output,
+            "protocols": protocols,
+            "proxy_list": arguments.proxy_list,
+            "tasks_count": tasks_count,
+        }
+
+    @staticmethod
+    def make_tasks_generator_func(tasks_data):
+        tasks_generator_func = functools.partial(
+            InputHelper._replace_target_variables_in_commands,
+            tasks=tasks_data["tasks"],
+            str_targets=tasks_data["str_targets"],
+            ipset_targets=tasks_data["ipset_targets"],
+        )
+
+        ports = tasks_data["ports"]
+        if ports:
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_in_commands,
+                tasks_generator_func=tasks_generator_func,
+                variable="_port_",
+                replacements=ports,
+            )
+
+        real_ports = tasks_data["real_ports"]
+        if real_ports:
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_in_commands,
+                tasks_generator_func=tasks_generator_func,
+                variable="_realport_",
+                replacements=real_ports,
+            )
+
+        random_file = tasks_data["random_file"]
+        if random_file:
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_in_commands,
+                tasks_generator_func=tasks_generator_func,
+                variable="_random_",
+                replacements=[random_file],
+            )
+
+        output = tasks_data["output"]
+        if output:
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_in_commands,
+                tasks_generator_func=tasks_generator_func,
+                variable="_output_",
+                replacements=[output],
+            )
+
+        protocols = tasks_data["protocols"]
+        if protocols:
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_in_commands,
+                tasks_generator_func=tasks_generator_func,
+                variable="_proto_",
+                replacements=protocols,
+            )
+
+        proxy_list = tasks_data["proxy_list"]
+        if proxy_list:
+            proxy_list_iterator = itertools.cycle(
+                proxy for proxy in (
+                    proxy.strip() for proxy in proxy_list
+                ) if proxy
+            )
+            tasks_generator_func = functools.partial(
+                InputHelper._replace_variable_array,
+                tasks_generator_func=tasks_generator_func,
+                variable="_proxy_",
+                replacements_iterator=proxy_list_iterator,
+            )
+
+        return tasks_generator_func


 class InputParser(object):
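make_tasks_generator_func builds the task stream as a chain of lazy stages: each functools.partial wraps the previous generator function, and no expansion happens until the outermost function is called and iterated. Because the result is a function rather than an already-consumed generator, interlace.py can call it once per repeat pass. A stripped-down sketch of the same pattern, with hypothetical stage names that are not part of the patch:

    import functools

    def base_stage(items):
        for item in items:
            yield item

    def replace_stage(stage_func, variable, value):
        # Wraps another generator function; substitution happens lazily,
        # one item at a time, as the consumer iterates.
        for item in stage_func():
            yield item.replace(variable, value)

    pipeline = functools.partial(base_stage, items=["nmap _target_"])
    pipeline = functools.partial(replace_stage, stage_func=pipeline,
                                 variable="_target_", value="192.0.2.1")

    print(list(pipeline()))  # ['nmap 192.0.2.1']
    print(list(pipeline()))  # calling again restarts the pipeline for a repeat pass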


@@ -66,17 +66,17 @@ class Worker(object):
         self.tqdm = tqdm

     def __call__(self):
+        queue = self.queue
         while True:
             try:
-                # get task from queue
-                task = self.queue.pop(0)
+                task = next(queue)
                 if isinstance(self.tqdm, tqdm):
                     self.tqdm.update(1)
                     # run task
                     task.run(self.tqdm)
                 else:
                     task.run()
-            except IndexError:
+            except StopIteration:
                 break
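With the queue now being a generator, each Worker pulls its next task with next() and treats StopIteration as the end-of-queue signal, where the old code popped from a shared list and stopped on IndexError. A minimal sketch of that consumption loop, using a stand-in queue of callables rather than real Task objects:

    def worker_loop(queue):
        # queue is any iterator; StopIteration means there is nothing left to run.
        while True:
            try:
                task = next(queue)
            except StopIteration:
                break
            task()

    worker_loop(iter([lambda: print("task 1"), lambda: print("task 2")]))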
@@ -90,17 +90,19 @@ class Pool(object):
         if max_workers <= 0:
             raise ValueError("Workers must be >= 1")

+        tasks_count = next(task_queue)
+
         # check if the queue is empty
-        if not task_queue:
+        if not tasks_count:
             raise ValueError("The queue is empty")

         self.queue = task_queue
         self.timeout = timeout
         self.output = output
-        self.max_workers = min(len(task_queue), max_workers)
+        self.max_workers = min(tasks_count, max_workers)

         if not progress_bar:
-            self.tqdm = tqdm(total=len(task_queue))
+            self.tqdm = tqdm(total=tasks_count)
         else:
             self.tqdm = True
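Pool now reads the first yielded value as the task count before starting any worker, and uses it to reject an empty queue, cap max_workers, and size the tqdm total. A short sketch of that hand-off against the counted-queue idea from earlier; the function below is illustrative and not the real threader API:

    def start_pool(task_queue, max_workers):
        # A counted queue yields its total first, then the tasks themselves.
        tasks_count = next(task_queue)
        if not tasks_count:
            raise ValueError("The queue is empty")
        workers = min(tasks_count, max_workers)
        print("running {} tasks on {} workers".format(tasks_count, workers))
        for task in task_queue:
            task()

    start_pool(iter([2, lambda: None, lambda: None]), max_workers=5)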