mirror of
https://github.com/codingo/Interlace.git
synced 2025-12-17 06:44:23 +01:00
Merge pull request #121 from whiteroses/fix-out-of-memory-issue-119
Improve memory usage in the generation of tasks (to fix issue #119)
This commit is contained in:
@@ -7,13 +7,18 @@ from Interlace.lib.core.output import OutputHelper, Level
|
|||||||
from Interlace.lib.threader import Pool
|
from Interlace.lib.threader import Pool
|
||||||
|
|
||||||
|
|
||||||
def build_queue(arguments, output, repeat):
|
def task_queue_generator_func(arguments, output, repeat):
|
||||||
task_list = InputHelper.process_commands(arguments)
|
tasks_data = InputHelper.process_data_for_tasks_iterator(arguments)
|
||||||
for task in task_list:
|
tasks_count = tasks_data["tasks_count"]
|
||||||
output.terminal(Level.THREAD, task.name(), "Added to Queue")
|
yield tasks_count
|
||||||
print('Generated {} commands in total'.format(len(task_list)))
|
tasks_generator_func = InputHelper.make_tasks_generator_func(tasks_data)
|
||||||
|
for i in range(repeat):
|
||||||
|
tasks_iterator = tasks_generator_func()
|
||||||
|
for task in tasks_iterator:
|
||||||
|
output.terminal(Level.THREAD, task.name(), "Added to Queue")
|
||||||
|
yield task
|
||||||
|
print('Generated {} commands in total'.format(tasks_count))
|
||||||
print('Repeat set to {}'.format(repeat))
|
print('Repeat set to {}'.format(repeat))
|
||||||
return task_list * repeat
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -27,9 +32,14 @@ def main():
|
|||||||
repeat = int(arguments.repeat)
|
repeat = int(arguments.repeat)
|
||||||
else:
|
else:
|
||||||
repeat = 1
|
repeat = 1
|
||||||
|
|
||||||
|
pool = Pool(
|
||||||
pool = Pool(arguments.threads, build_queue(arguments, output, repeat), arguments.timeout, output, arguments.sober)
|
arguments.threads,
|
||||||
|
task_queue_generator_func(arguments, output, repeat),
|
||||||
|
arguments.timeout,
|
||||||
|
output,
|
||||||
|
arguments.sober,
|
||||||
|
)
|
||||||
pool.run()
|
pool.run()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.8.2'
|
__version__ = '1.9.0'
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
|
import functools
|
||||||
|
import itertools
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from math import ceil
|
from random import choice
|
||||||
from random import sample, choice
|
|
||||||
|
|
||||||
from netaddr import IPNetwork, IPRange, IPGlob
|
from netaddr import (
|
||||||
|
IPRange,
|
||||||
|
IPSet,
|
||||||
|
glob_to_iprange,
|
||||||
|
)
|
||||||
|
|
||||||
from Interlace.lib.threader import Task
|
from Interlace.lib.threader import Task
|
||||||
|
|
||||||
@@ -45,42 +50,6 @@ class InputHelper(object):
|
|||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_ips_from_range(ip_range):
|
|
||||||
ips = set()
|
|
||||||
ip_range = ip_range.split("-")
|
|
||||||
|
|
||||||
# parsing the above structure into an array and then making into an IP address with the end value
|
|
||||||
end_ip = ".".join(ip_range[0].split(".")[0:-1]) + "." + ip_range[1]
|
|
||||||
|
|
||||||
# creating an IPRange object to get all IPs in between
|
|
||||||
range_obj = IPRange(ip_range[0], end_ip)
|
|
||||||
|
|
||||||
for ip in range_obj:
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_ips_from_glob(glob_ips):
|
|
||||||
ip_glob = IPGlob(glob_ips)
|
|
||||||
|
|
||||||
ips = set()
|
|
||||||
|
|
||||||
for ip in ip_glob:
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_cidr_to_ips(cidr_range):
|
|
||||||
ips = set()
|
|
||||||
|
|
||||||
for ip in IPNetwork(cidr_range):
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_port(port_type):
|
def _process_port(port_type):
|
||||||
if "," in port_type:
|
if "," in port_type:
|
||||||
@@ -146,178 +115,256 @@ class InputHelper(object):
|
|||||||
return task_block
|
return task_block
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pre_process_hosts(host_ranges, destination_set, arguments):
|
def _replace_target_variables_in_commands(tasks, str_targets, ipset_targets):
|
||||||
for host in host_ranges:
|
TARGET_VAR = "_target_"
|
||||||
host = host.replace(" ", "").replace("\n", "")
|
HOST_VAR = "_host_"
|
||||||
# check if it is a domain name
|
CLEANTARGET_VAR = "_cleantarget_"
|
||||||
if len(host.split(".")[0]) == 0:
|
for task in tasks:
|
||||||
destination_set.add(host)
|
command = task.name()
|
||||||
continue
|
if TARGET_VAR in command or HOST_VAR in command:
|
||||||
|
for dirty_target in itertools.chain(str_targets, ipset_targets):
|
||||||
|
yielded_task = task.clone()
|
||||||
|
dirty_target = str(dirty_target)
|
||||||
|
yielded_task.replace(TARGET_VAR, dirty_target)
|
||||||
|
yielded_task.replace(HOST_VAR, dirty_target)
|
||||||
|
yielded_task.replace(
|
||||||
|
CLEANTARGET_VAR,
|
||||||
|
dirty_target.replace("http://", "").replace(
|
||||||
|
"https://", "").rstrip("/").replace("/", "-"),
|
||||||
|
)
|
||||||
|
yield yielded_task
|
||||||
|
else:
|
||||||
|
yield task
|
||||||
|
|
||||||
if host.split(".")[0][0].isalpha() or host.split(".")[-1][-1].isalpha():
|
@staticmethod
|
||||||
destination_set.add(host)
|
def _replace_variable_in_commands(tasks_generator_func, variable, replacements):
|
||||||
continue
|
for task in tasks_generator_func():
|
||||||
for ips in host.split(","):
|
if variable in task.name():
|
||||||
# checking for CIDR
|
for replacement in replacements:
|
||||||
if not arguments.nocidr and "/" in ips:
|
yielded_task = task.clone()
|
||||||
destination_set.update(InputHelper._get_cidr_to_ips(ips))
|
yielded_task.replace(variable, str(replacement))
|
||||||
# checking for IPs in a range
|
yield yielded_task
|
||||||
elif "-" in ips:
|
else:
|
||||||
destination_set.update(InputHelper._get_ips_from_range(ips))
|
yield task
|
||||||
# checking for glob ranges
|
|
||||||
elif "*" in ips:
|
@staticmethod
|
||||||
destination_set.update(InputHelper._get_ips_from_glob(ips))
|
def _replace_variable_array(
|
||||||
|
tasks_generator_func, variable, replacements_iterator
|
||||||
|
):
|
||||||
|
for task in tasks_generator_func():
|
||||||
|
task.replace(variable, str(next(replacements_iterator)))
|
||||||
|
yield task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_targets(arguments):
|
||||||
|
def pre_process_target_spec(target_spec):
|
||||||
|
target_spec = "".join(
|
||||||
|
filter(lambda char: char not in (" ", "\n"), target_spec)
|
||||||
|
)
|
||||||
|
return target_spec.split(",")
|
||||||
|
# If ","s not in target_spec, this returns [target_spec], so this
|
||||||
|
# static method always returns a list
|
||||||
|
|
||||||
|
if arguments.target:
|
||||||
|
target_specs = pre_process_target_spec(arguments.target)
|
||||||
|
else:
|
||||||
|
target_specs_file = arguments.target_list
|
||||||
|
if not isinstance(target_specs_file, TextIOWrapper):
|
||||||
|
if not sys.stdin.isatty():
|
||||||
|
target_specs_file = sys.stdin
|
||||||
|
target_specs = (
|
||||||
|
target_spec.strip() for target_spec in target_specs_file
|
||||||
|
)
|
||||||
|
target_specs = (
|
||||||
|
pre_process_target_spec(target_spec) for target_spec in
|
||||||
|
target_specs if target_spec
|
||||||
|
)
|
||||||
|
target_specs = itertools.chain(*target_specs)
|
||||||
|
|
||||||
|
def parse_and_group_target_specs(target_specs, nocidr):
|
||||||
|
str_targets = set()
|
||||||
|
ipset_targets = IPSet()
|
||||||
|
for target_spec in target_specs:
|
||||||
|
if (
|
||||||
|
target_spec.startswith(".") or
|
||||||
|
(
|
||||||
|
(target_spec[0].isalpha() or target_spec[-1].isalpha())
|
||||||
|
and "." in target_spec
|
||||||
|
) or
|
||||||
|
(nocidr and "/" in target_spec)
|
||||||
|
):
|
||||||
|
str_targets.add(target_spec)
|
||||||
else:
|
else:
|
||||||
destination_set.add(ips)
|
if "-" in target_spec:
|
||||||
|
start_ip, post_dash_segment = target_spec.split("-")
|
||||||
|
end_ip = start_ip.rsplit(".", maxsplit=1)[0] + "." + \
|
||||||
|
post_dash_segment
|
||||||
|
target_spec = IPRange(start_ip, end_ip)
|
||||||
|
elif "*" in target_spec:
|
||||||
|
target_spec = glob_to_iprange(target_spec)
|
||||||
|
else: # str IP addresses and str CIDR notations
|
||||||
|
target_spec = (target_spec,)
|
||||||
|
ipset_targets.update(IPSet(target_spec))
|
||||||
|
return (str_targets, ipset_targets)
|
||||||
|
|
||||||
|
str_targets, ipset_targets = parse_and_group_target_specs(
|
||||||
|
target_specs=target_specs,
|
||||||
|
nocidr=arguments.nocidr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if arguments.exclusions or arguments.exclusions_list:
|
||||||
|
if arguments.exclusions:
|
||||||
|
exclusion_specs = pre_process_target_spec(arguments.exclusions)
|
||||||
|
elif arguments.exclusions_list:
|
||||||
|
exclusion_specs = (
|
||||||
|
exclusion_spec.strip() for exclusion_spec in
|
||||||
|
arguments.exclusions_list
|
||||||
|
)
|
||||||
|
exclusion_specs = (
|
||||||
|
pre_process_target_spec(exclusion_spec) for exclusion_spec
|
||||||
|
in exclusion_specs if exclusion_spec
|
||||||
|
)
|
||||||
|
exclusion_specs = itertools.chain(*exclusion_specs)
|
||||||
|
str_exclusions, ipset_exclusions = parse_and_group_target_specs(
|
||||||
|
target_specs=exclusion_specs,
|
||||||
|
nocidr=arguments.nocidr,
|
||||||
|
)
|
||||||
|
str_targets -= str_exclusions
|
||||||
|
ipset_targets -= ipset_exclusions
|
||||||
|
|
||||||
|
return (str_targets, ipset_targets)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_clean_targets(commands, dirty_targets):
|
def process_data_for_tasks_iterator(arguments):
|
||||||
def add_task(t, item_list, my_command_set):
|
|
||||||
if t not in my_command_set:
|
|
||||||
my_command_set.add(t)
|
|
||||||
item_list.append(t)
|
|
||||||
|
|
||||||
variable = '_cleantarget_'
|
|
||||||
tasks = []
|
|
||||||
temp = set() # this helps avoid command duplication and re/deconstructing of temporary set
|
|
||||||
# changed order to ensure different combinations of commands aren't created
|
|
||||||
for dirty_target in dirty_targets:
|
|
||||||
for command in commands:
|
|
||||||
new_task = command.clone()
|
|
||||||
if command.name().find(variable) != -1:
|
|
||||||
new_task.replace("_target_", dirty_target)
|
|
||||||
|
|
||||||
# replace all https:// or https:// with nothing
|
|
||||||
dirty_target = dirty_target.replace('http://', '')
|
|
||||||
dirty_target = dirty_target.replace('https://', '')
|
|
||||||
# chop off all trailing '/', if any.
|
|
||||||
while dirty_target.endswith('/'):
|
|
||||||
dirty_target = dirty_target.strip('/')
|
|
||||||
# replace all remaining '/' with '-' and that's enough cleanup for the day
|
|
||||||
clean_target = dirty_target.replace('/', '-')
|
|
||||||
new_task.replace(variable, clean_target)
|
|
||||||
add_task(new_task, tasks, temp)
|
|
||||||
else:
|
|
||||||
new_task.replace("_target_", dirty_target)
|
|
||||||
add_task(new_task, tasks, temp)
|
|
||||||
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _replace_variable_with_commands(commands, variable, replacements):
|
|
||||||
def add_task(t, item_list, my_set):
|
|
||||||
if t not in my_set:
|
|
||||||
my_set.add(t)
|
|
||||||
item_list.append(t)
|
|
||||||
|
|
||||||
tasks = []
|
|
||||||
temp_set = set() # to avoid duplicates
|
|
||||||
for command in commands:
|
|
||||||
for replacement in replacements:
|
|
||||||
if command.name().find(variable) != -1:
|
|
||||||
new_task = command.clone()
|
|
||||||
new_task.replace(variable, str(replacement))
|
|
||||||
add_task(new_task, tasks, temp_set)
|
|
||||||
else:
|
|
||||||
add_task(command, tasks, temp_set)
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _replace_variable_array(commands, variable, replacement):
|
|
||||||
if variable not in sample(commands, 1)[0]:
|
|
||||||
return
|
|
||||||
|
|
||||||
for counter, command in enumerate(commands):
|
|
||||||
command.replace(variable, str(replacement[counter]))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_commands(arguments):
|
|
||||||
commands = list()
|
|
||||||
ranges = set()
|
|
||||||
targets = set()
|
|
||||||
exclusions_ranges = set()
|
|
||||||
exclusions = set()
|
|
||||||
|
|
||||||
# removing the trailing slash if any
|
# removing the trailing slash if any
|
||||||
if arguments.output and arguments.output[-1] == "/":
|
if arguments.output and arguments.output[-1] == "/":
|
||||||
arguments.output = arguments.output[:-1]
|
arguments.output = arguments.output[:-1]
|
||||||
|
|
||||||
if arguments.port:
|
ports = InputHelper._process_port(arguments.port) if arguments.port \
|
||||||
ports = InputHelper._process_port(arguments.port)
|
else None
|
||||||
|
|
||||||
if arguments.realport:
|
real_ports = InputHelper._process_port(arguments.realport) if \
|
||||||
real_ports = InputHelper._process_port(arguments.realport)
|
arguments.realport else None
|
||||||
|
|
||||||
# process targets first
|
str_targets, ipset_targets = InputHelper._process_targets(
|
||||||
if arguments.target:
|
arguments=arguments,
|
||||||
ranges.add(arguments.target)
|
)
|
||||||
else:
|
targets_count = len(str_targets) + ipset_targets.size
|
||||||
target_file = arguments.target_list
|
|
||||||
if not isinstance(target_file, TextIOWrapper):
|
|
||||||
if not sys.stdin.isatty():
|
|
||||||
target_file = sys.stdin
|
|
||||||
ranges.update([target.strip() for target in target_file if target.strip()])
|
|
||||||
|
|
||||||
# process exclusions first
|
if not targets_count:
|
||||||
if arguments.exclusions:
|
|
||||||
exclusions_ranges.add(arguments.exclusions)
|
|
||||||
else:
|
|
||||||
if arguments.exclusions_list:
|
|
||||||
for exclusion in arguments.exclusions_list:
|
|
||||||
exclusion = exclusion.strip()
|
|
||||||
if exclusion:
|
|
||||||
exclusions.add(exclusion)
|
|
||||||
|
|
||||||
# removing elements that may have spaces (helpful for easily processing comma notation)
|
|
||||||
InputHelper._pre_process_hosts(ranges, targets, arguments)
|
|
||||||
InputHelper._pre_process_hosts(exclusions_ranges, exclusions, arguments)
|
|
||||||
|
|
||||||
# difference operation
|
|
||||||
targets -= exclusions
|
|
||||||
|
|
||||||
if len(targets) == 0:
|
|
||||||
raise Exception("No target provided, or empty target list")
|
raise Exception("No target provided, or empty target list")
|
||||||
|
|
||||||
if arguments.random:
|
if arguments.random:
|
||||||
files = InputHelper._get_files_from_directory(arguments.random)
|
files = InputHelper._get_files_from_directory(arguments.random)
|
||||||
random_file = choice(files)
|
random_file = choice(files)
|
||||||
|
|
||||||
if arguments.command:
|
|
||||||
commands.append(Task(arguments.command.rstrip('\n')))
|
|
||||||
else:
|
else:
|
||||||
commands = InputHelper._pre_process_commands(arguments.command_list)
|
random_file = None
|
||||||
|
|
||||||
# commands = InputHelper._replace_variable_with_commands(commands, "_target_", targets)
|
tasks = list()
|
||||||
commands = InputHelper._process_clean_targets(commands, targets)
|
if arguments.command:
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_host_", targets)
|
tasks.append(Task(arguments.command.rstrip('\n')))
|
||||||
|
else:
|
||||||
if arguments.port:
|
tasks = InputHelper._pre_process_commands(arguments.command_list)
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_port_", ports)
|
|
||||||
|
|
||||||
if arguments.realport:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_realport_", real_ports)
|
|
||||||
|
|
||||||
if arguments.random:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_random_", [random_file])
|
|
||||||
|
|
||||||
if arguments.output:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_output_", [arguments.output])
|
|
||||||
|
|
||||||
if arguments.proto:
|
if arguments.proto:
|
||||||
if "," in arguments.proto:
|
protocols = arguments.proto.split(",")
|
||||||
protocols = arguments.proto.split(",")
|
# if "," not in arguments.proto, [arguments.proto] is returned by
|
||||||
else:
|
# .split()
|
||||||
protocols = arguments.proto
|
else:
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_proto_", protocols)
|
protocols = None
|
||||||
|
|
||||||
# process proxies
|
# Calculate the tasks count, as we will not have access to the len() of
|
||||||
if arguments.proxy_list:
|
# the tasks iterator
|
||||||
proxy_list = [proxy for proxy in arguments.proxy_list if proxy.strip()]
|
tasks_count = len(tasks) * targets_count
|
||||||
if len(proxy_list) < len(commands):
|
if ports:
|
||||||
proxy_list = ceil(len(commands) / len(proxy_list)) * proxy_list
|
tasks_count *= len(ports)
|
||||||
|
if real_ports:
|
||||||
|
tasks_count *= len(real_ports)
|
||||||
|
if protocols:
|
||||||
|
tasks_count *= len(protocols)
|
||||||
|
|
||||||
InputHelper._replace_variable_array(commands, "_proxy_", proxy_list)
|
return {
|
||||||
return commands
|
"tasks": tasks,
|
||||||
|
"str_targets": str_targets,
|
||||||
|
"ipset_targets": ipset_targets,
|
||||||
|
"ports": ports,
|
||||||
|
"real_ports": real_ports,
|
||||||
|
"random_file": random_file,
|
||||||
|
"output": arguments.output,
|
||||||
|
"protocols": protocols,
|
||||||
|
"proxy_list": arguments.proxy_list,
|
||||||
|
"tasks_count": tasks_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_tasks_generator_func(tasks_data):
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_target_variables_in_commands,
|
||||||
|
tasks=tasks_data["tasks"],
|
||||||
|
str_targets=tasks_data["str_targets"],
|
||||||
|
ipset_targets=tasks_data["ipset_targets"],
|
||||||
|
)
|
||||||
|
|
||||||
|
ports = tasks_data["ports"]
|
||||||
|
if ports:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_port_",
|
||||||
|
replacements=ports,
|
||||||
|
)
|
||||||
|
|
||||||
|
real_ports = tasks_data["real_ports"]
|
||||||
|
if real_ports:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_realport_",
|
||||||
|
replacements=real_ports,
|
||||||
|
)
|
||||||
|
|
||||||
|
random_file = tasks_data["random_file"]
|
||||||
|
if random_file:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_random_",
|
||||||
|
replacements=[random_file],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = tasks_data["output"]
|
||||||
|
if output:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_output_",
|
||||||
|
replacements=[output],
|
||||||
|
)
|
||||||
|
|
||||||
|
protocols = tasks_data["protocols"]
|
||||||
|
if protocols:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_proto_",
|
||||||
|
replacements=protocols,
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_list = tasks_data["proxy_list"]
|
||||||
|
if proxy_list:
|
||||||
|
proxy_list_iterator = itertools.cycle(
|
||||||
|
proxy for proxy in (
|
||||||
|
proxy.strip() for proxy in proxy_list
|
||||||
|
) if proxy
|
||||||
|
)
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_array,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_proxy_",
|
||||||
|
replacements_iterator=proxy_list_iterator,
|
||||||
|
)
|
||||||
|
|
||||||
|
return tasks_generator_func
|
||||||
|
|
||||||
|
|
||||||
class InputParser(object):
|
class InputParser(object):
|
||||||
|
|||||||
@@ -66,17 +66,17 @@ class Worker(object):
|
|||||||
self.tqdm = tqdm
|
self.tqdm = tqdm
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
|
queue = self.queue
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# get task from queue
|
task = next(queue)
|
||||||
task = self.queue.pop(0)
|
|
||||||
if isinstance(self.tqdm, tqdm):
|
if isinstance(self.tqdm, tqdm):
|
||||||
self.tqdm.update(1)
|
self.tqdm.update(1)
|
||||||
# run task
|
# run task
|
||||||
task.run(self.tqdm)
|
task.run(self.tqdm)
|
||||||
else:
|
else:
|
||||||
task.run()
|
task.run()
|
||||||
except IndexError:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@@ -90,17 +90,19 @@ class Pool(object):
|
|||||||
if max_workers <= 0:
|
if max_workers <= 0:
|
||||||
raise ValueError("Workers must be >= 1")
|
raise ValueError("Workers must be >= 1")
|
||||||
|
|
||||||
|
tasks_count = next(task_queue)
|
||||||
|
|
||||||
# check if the queue is empty
|
# check if the queue is empty
|
||||||
if not task_queue:
|
if not tasks_count:
|
||||||
raise ValueError("The queue is empty")
|
raise ValueError("The queue is empty")
|
||||||
|
|
||||||
self.queue = task_queue
|
self.queue = task_queue
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.output = output
|
self.output = output
|
||||||
self.max_workers = min(len(task_queue), max_workers)
|
self.max_workers = min(tasks_count, max_workers)
|
||||||
|
|
||||||
if not progress_bar:
|
if not progress_bar:
|
||||||
self.tqdm = tqdm(total=len(task_queue))
|
self.tqdm = tqdm(total=tasks_count)
|
||||||
else:
|
else:
|
||||||
self.tqdm = True
|
self.tqdm = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user