mirror of
https://github.com/codingo/Interlace.git
synced 2025-12-18 23:34:19 +01:00
Improve memory usage in the generation of tasks (to fix issue #119).
Switch to generating tasks using iterators, and make other changes such as using netaddr's IPSet to store ranges of IP addresses, in order to reduce the use of memory where possible.
This commit is contained in:
@@ -7,13 +7,18 @@ from Interlace.lib.core.output import OutputHelper, Level
|
|||||||
from Interlace.lib.threader import Pool
|
from Interlace.lib.threader import Pool
|
||||||
|
|
||||||
|
|
||||||
def build_queue(arguments, output, repeat):
|
def task_queue_generator_func(arguments, output, repeat):
|
||||||
task_list = InputHelper.process_commands(arguments)
|
tasks_data = InputHelper.process_data_for_tasks_iterator(arguments)
|
||||||
for task in task_list:
|
tasks_count = tasks_data["tasks_count"]
|
||||||
output.terminal(Level.THREAD, task.name(), "Added to Queue")
|
yield tasks_count
|
||||||
print('Generated {} commands in total'.format(len(task_list)))
|
tasks_generator_func = InputHelper.make_tasks_generator_func(tasks_data)
|
||||||
|
for i in range(repeat):
|
||||||
|
tasks_iterator = tasks_generator_func()
|
||||||
|
for task in tasks_iterator:
|
||||||
|
output.terminal(Level.THREAD, task.name(), "Added to Queue")
|
||||||
|
yield task
|
||||||
|
print('Generated {} commands in total'.format(tasks_count))
|
||||||
print('Repeat set to {}'.format(repeat))
|
print('Repeat set to {}'.format(repeat))
|
||||||
return task_list * repeat
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -28,8 +33,13 @@ def main():
|
|||||||
else:
|
else:
|
||||||
repeat = 1
|
repeat = 1
|
||||||
|
|
||||||
|
pool = Pool(
|
||||||
pool = Pool(arguments.threads, build_queue(arguments, output, repeat), arguments.timeout, output, arguments.sober)
|
arguments.threads,
|
||||||
|
task_queue_generator_func(arguments, output, repeat),
|
||||||
|
arguments.timeout,
|
||||||
|
output,
|
||||||
|
arguments.sober,
|
||||||
|
)
|
||||||
pool.run()
|
pool.run()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
|
import functools
|
||||||
|
import itertools
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from math import ceil
|
from random import choice
|
||||||
from random import sample, choice
|
|
||||||
|
|
||||||
from netaddr import IPNetwork, IPRange, IPGlob
|
from netaddr import (
|
||||||
|
IPRange,
|
||||||
|
IPSet,
|
||||||
|
glob_to_iprange,
|
||||||
|
)
|
||||||
|
|
||||||
from Interlace.lib.threader import Task
|
from Interlace.lib.threader import Task
|
||||||
|
|
||||||
@@ -45,42 +50,6 @@ class InputHelper(object):
|
|||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_ips_from_range(ip_range):
|
|
||||||
ips = set()
|
|
||||||
ip_range = ip_range.split("-")
|
|
||||||
|
|
||||||
# parsing the above structure into an array and then making into an IP address with the end value
|
|
||||||
end_ip = ".".join(ip_range[0].split(".")[0:-1]) + "." + ip_range[1]
|
|
||||||
|
|
||||||
# creating an IPRange object to get all IPs in between
|
|
||||||
range_obj = IPRange(ip_range[0], end_ip)
|
|
||||||
|
|
||||||
for ip in range_obj:
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_ips_from_glob(glob_ips):
|
|
||||||
ip_glob = IPGlob(glob_ips)
|
|
||||||
|
|
||||||
ips = set()
|
|
||||||
|
|
||||||
for ip in ip_glob:
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_cidr_to_ips(cidr_range):
|
|
||||||
ips = set()
|
|
||||||
|
|
||||||
for ip in IPNetwork(cidr_range):
|
|
||||||
ips.add(str(ip))
|
|
||||||
|
|
||||||
return ips
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_port(port_type):
|
def _process_port(port_type):
|
||||||
if "," in port_type:
|
if "," in port_type:
|
||||||
@@ -146,178 +115,256 @@ class InputHelper(object):
|
|||||||
return task_block
|
return task_block
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pre_process_hosts(host_ranges, destination_set, arguments):
|
def _replace_target_variables_in_commands(tasks, str_targets, ipset_targets):
|
||||||
for host in host_ranges:
|
TARGET_VAR = "_target_"
|
||||||
host = host.replace(" ", "").replace("\n", "")
|
HOST_VAR = "_host_"
|
||||||
# check if it is a domain name
|
CLEANTARGET_VAR = "_cleantarget_"
|
||||||
if len(host.split(".")[0]) == 0:
|
for task in tasks:
|
||||||
destination_set.add(host)
|
command = task.name()
|
||||||
continue
|
if TARGET_VAR in command or HOST_VAR in command:
|
||||||
|
for dirty_target in itertools.chain(str_targets, ipset_targets):
|
||||||
|
yielded_task = task.clone()
|
||||||
|
dirty_target = str(dirty_target)
|
||||||
|
yielded_task.replace(TARGET_VAR, dirty_target)
|
||||||
|
yielded_task.replace(HOST_VAR, dirty_target)
|
||||||
|
yielded_task.replace(
|
||||||
|
CLEANTARGET_VAR,
|
||||||
|
dirty_target.replace("http://", "").replace(
|
||||||
|
"https://", "").rstrip("/").replace("/", "-"),
|
||||||
|
)
|
||||||
|
yield yielded_task
|
||||||
|
else:
|
||||||
|
yield task
|
||||||
|
|
||||||
if host.split(".")[0][0].isalpha() or host.split(".")[-1][-1].isalpha():
|
@staticmethod
|
||||||
destination_set.add(host)
|
def _replace_variable_in_commands(tasks_generator_func, variable, replacements):
|
||||||
continue
|
for task in tasks_generator_func():
|
||||||
for ips in host.split(","):
|
if variable in task.name():
|
||||||
# checking for CIDR
|
for replacement in replacements:
|
||||||
if not arguments.nocidr and "/" in ips:
|
yielded_task = task.clone()
|
||||||
destination_set.update(InputHelper._get_cidr_to_ips(ips))
|
yielded_task.replace(variable, str(replacement))
|
||||||
# checking for IPs in a range
|
yield yielded_task
|
||||||
elif "-" in ips:
|
else:
|
||||||
destination_set.update(InputHelper._get_ips_from_range(ips))
|
yield task
|
||||||
# checking for glob ranges
|
|
||||||
elif "*" in ips:
|
@staticmethod
|
||||||
destination_set.update(InputHelper._get_ips_from_glob(ips))
|
def _replace_variable_array(
|
||||||
|
tasks_generator_func, variable, replacements_iterator
|
||||||
|
):
|
||||||
|
for task in tasks_generator_func():
|
||||||
|
task.replace(variable, str(next(replacements_iterator)))
|
||||||
|
yield task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_targets(arguments):
|
||||||
|
def pre_process_target_spec(target_spec):
|
||||||
|
target_spec = "".join(
|
||||||
|
filter(lambda char: char not in (" ", "\n"), target_spec)
|
||||||
|
)
|
||||||
|
return target_spec.split(",")
|
||||||
|
# If ","s not in target_spec, this returns [target_spec], so this
|
||||||
|
# static method always returns a list
|
||||||
|
|
||||||
|
if arguments.target:
|
||||||
|
target_specs = pre_process_target_spec(arguments.target)
|
||||||
|
else:
|
||||||
|
target_specs_file = arguments.target_list
|
||||||
|
if not isinstance(target_specs_file, TextIOWrapper):
|
||||||
|
if not sys.stdin.isatty():
|
||||||
|
target_specs_file = sys.stdin
|
||||||
|
target_specs = (
|
||||||
|
target_spec.strip() for target_spec in target_specs_file
|
||||||
|
)
|
||||||
|
target_specs = (
|
||||||
|
pre_process_target_spec(target_spec) for target_spec in
|
||||||
|
target_specs if target_spec
|
||||||
|
)
|
||||||
|
target_specs = itertools.chain(*target_specs)
|
||||||
|
|
||||||
|
def parse_and_group_target_specs(target_specs, nocidr):
|
||||||
|
str_targets = set()
|
||||||
|
ipset_targets = IPSet()
|
||||||
|
for target_spec in target_specs:
|
||||||
|
if (
|
||||||
|
target_spec.startswith(".") or
|
||||||
|
(
|
||||||
|
(target_spec[0].isalpha() or target_spec[-1].isalpha())
|
||||||
|
and "." in target_spec
|
||||||
|
) or
|
||||||
|
(nocidr and "/" in target_spec)
|
||||||
|
):
|
||||||
|
str_targets.add(target_spec)
|
||||||
else:
|
else:
|
||||||
destination_set.add(ips)
|
if "-" in target_spec:
|
||||||
|
start_ip, post_dash_segment = target_spec.split("-")
|
||||||
|
end_ip = start_ip.rsplit(".", maxsplit=1)[0] + "." + \
|
||||||
|
post_dash_segment
|
||||||
|
target_spec = IPRange(start_ip, end_ip)
|
||||||
|
elif "*" in target_spec:
|
||||||
|
target_spec = glob_to_iprange(target_spec)
|
||||||
|
else: # str IP addresses and str CIDR notations
|
||||||
|
target_spec = (target_spec,)
|
||||||
|
ipset_targets.update(IPSet(target_spec))
|
||||||
|
return (str_targets, ipset_targets)
|
||||||
|
|
||||||
|
str_targets, ipset_targets = parse_and_group_target_specs(
|
||||||
|
target_specs=target_specs,
|
||||||
|
nocidr=arguments.nocidr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if arguments.exclusions or arguments.exclusions_list:
|
||||||
|
if arguments.exclusions:
|
||||||
|
exclusion_specs = pre_process_target_spec(arguments.exclusions)
|
||||||
|
elif arguments.exclusions_list:
|
||||||
|
exclusion_specs = (
|
||||||
|
exclusion_spec.strip() for exclusion_spec in
|
||||||
|
arguments.exclusions_list
|
||||||
|
)
|
||||||
|
exclusion_specs = (
|
||||||
|
pre_process_target_spec(exclusion_spec) for exclusion_spec
|
||||||
|
in exclusion_specs if exclusion_spec
|
||||||
|
)
|
||||||
|
exclusion_specs = itertools.chain(*exclusion_specs)
|
||||||
|
str_exclusions, ipset_exclusions = parse_and_group_target_specs(
|
||||||
|
target_specs=exclusion_specs,
|
||||||
|
nocidr=arguments.nocidr,
|
||||||
|
)
|
||||||
|
str_targets -= str_exclusions
|
||||||
|
ipset_targets -= ipset_exclusions
|
||||||
|
|
||||||
|
return (str_targets, ipset_targets)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_clean_targets(commands, dirty_targets):
|
def process_data_for_tasks_iterator(arguments):
|
||||||
def add_task(t, item_list, my_command_set):
|
|
||||||
if t not in my_command_set:
|
|
||||||
my_command_set.add(t)
|
|
||||||
item_list.append(t)
|
|
||||||
|
|
||||||
variable = '_cleantarget_'
|
|
||||||
tasks = []
|
|
||||||
temp = set() # this helps avoid command duplication and re/deconstructing of temporary set
|
|
||||||
# changed order to ensure different combinations of commands aren't created
|
|
||||||
for dirty_target in dirty_targets:
|
|
||||||
for command in commands:
|
|
||||||
new_task = command.clone()
|
|
||||||
if command.name().find(variable) != -1:
|
|
||||||
new_task.replace("_target_", dirty_target)
|
|
||||||
|
|
||||||
# replace all https:// or https:// with nothing
|
|
||||||
dirty_target = dirty_target.replace('http://', '')
|
|
||||||
dirty_target = dirty_target.replace('https://', '')
|
|
||||||
# chop off all trailing '/', if any.
|
|
||||||
while dirty_target.endswith('/'):
|
|
||||||
dirty_target = dirty_target.strip('/')
|
|
||||||
# replace all remaining '/' with '-' and that's enough cleanup for the day
|
|
||||||
clean_target = dirty_target.replace('/', '-')
|
|
||||||
new_task.replace(variable, clean_target)
|
|
||||||
add_task(new_task, tasks, temp)
|
|
||||||
else:
|
|
||||||
new_task.replace("_target_", dirty_target)
|
|
||||||
add_task(new_task, tasks, temp)
|
|
||||||
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _replace_variable_with_commands(commands, variable, replacements):
|
|
||||||
def add_task(t, item_list, my_set):
|
|
||||||
if t not in my_set:
|
|
||||||
my_set.add(t)
|
|
||||||
item_list.append(t)
|
|
||||||
|
|
||||||
tasks = []
|
|
||||||
temp_set = set() # to avoid duplicates
|
|
||||||
for command in commands:
|
|
||||||
for replacement in replacements:
|
|
||||||
if command.name().find(variable) != -1:
|
|
||||||
new_task = command.clone()
|
|
||||||
new_task.replace(variable, str(replacement))
|
|
||||||
add_task(new_task, tasks, temp_set)
|
|
||||||
else:
|
|
||||||
add_task(command, tasks, temp_set)
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _replace_variable_array(commands, variable, replacement):
|
|
||||||
if variable not in sample(commands, 1)[0]:
|
|
||||||
return
|
|
||||||
|
|
||||||
for counter, command in enumerate(commands):
|
|
||||||
command.replace(variable, str(replacement[counter]))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_commands(arguments):
|
|
||||||
commands = list()
|
|
||||||
ranges = set()
|
|
||||||
targets = set()
|
|
||||||
exclusions_ranges = set()
|
|
||||||
exclusions = set()
|
|
||||||
|
|
||||||
# removing the trailing slash if any
|
# removing the trailing slash if any
|
||||||
if arguments.output and arguments.output[-1] == "/":
|
if arguments.output and arguments.output[-1] == "/":
|
||||||
arguments.output = arguments.output[:-1]
|
arguments.output = arguments.output[:-1]
|
||||||
|
|
||||||
if arguments.port:
|
ports = InputHelper._process_port(arguments.port) if arguments.port \
|
||||||
ports = InputHelper._process_port(arguments.port)
|
else None
|
||||||
|
|
||||||
if arguments.realport:
|
real_ports = InputHelper._process_port(arguments.realport) if \
|
||||||
real_ports = InputHelper._process_port(arguments.realport)
|
arguments.realport else None
|
||||||
|
|
||||||
# process targets first
|
str_targets, ipset_targets = InputHelper._process_targets(
|
||||||
if arguments.target:
|
arguments=arguments,
|
||||||
ranges.add(arguments.target)
|
)
|
||||||
else:
|
targets_count = len(str_targets) + ipset_targets.size
|
||||||
target_file = arguments.target_list
|
|
||||||
if not isinstance(target_file, TextIOWrapper):
|
|
||||||
if not sys.stdin.isatty():
|
|
||||||
target_file = sys.stdin
|
|
||||||
ranges.update([target.strip() for target in target_file if target.strip()])
|
|
||||||
|
|
||||||
# process exclusions first
|
if not targets_count:
|
||||||
if arguments.exclusions:
|
|
||||||
exclusions_ranges.add(arguments.exclusions)
|
|
||||||
else:
|
|
||||||
if arguments.exclusions_list:
|
|
||||||
for exclusion in arguments.exclusions_list:
|
|
||||||
exclusion = exclusion.strip()
|
|
||||||
if exclusion:
|
|
||||||
exclusions.add(exclusion)
|
|
||||||
|
|
||||||
# removing elements that may have spaces (helpful for easily processing comma notation)
|
|
||||||
InputHelper._pre_process_hosts(ranges, targets, arguments)
|
|
||||||
InputHelper._pre_process_hosts(exclusions_ranges, exclusions, arguments)
|
|
||||||
|
|
||||||
# difference operation
|
|
||||||
targets -= exclusions
|
|
||||||
|
|
||||||
if len(targets) == 0:
|
|
||||||
raise Exception("No target provided, or empty target list")
|
raise Exception("No target provided, or empty target list")
|
||||||
|
|
||||||
if arguments.random:
|
if arguments.random:
|
||||||
files = InputHelper._get_files_from_directory(arguments.random)
|
files = InputHelper._get_files_from_directory(arguments.random)
|
||||||
random_file = choice(files)
|
random_file = choice(files)
|
||||||
|
|
||||||
if arguments.command:
|
|
||||||
commands.append(Task(arguments.command.rstrip('\n')))
|
|
||||||
else:
|
else:
|
||||||
commands = InputHelper._pre_process_commands(arguments.command_list)
|
random_file = None
|
||||||
|
|
||||||
# commands = InputHelper._replace_variable_with_commands(commands, "_target_", targets)
|
tasks = list()
|
||||||
commands = InputHelper._process_clean_targets(commands, targets)
|
if arguments.command:
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_host_", targets)
|
tasks.append(Task(arguments.command.rstrip('\n')))
|
||||||
|
else:
|
||||||
if arguments.port:
|
tasks = InputHelper._pre_process_commands(arguments.command_list)
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_port_", ports)
|
|
||||||
|
|
||||||
if arguments.realport:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_realport_", real_ports)
|
|
||||||
|
|
||||||
if arguments.random:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_random_", [random_file])
|
|
||||||
|
|
||||||
if arguments.output:
|
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_output_", [arguments.output])
|
|
||||||
|
|
||||||
if arguments.proto:
|
if arguments.proto:
|
||||||
if "," in arguments.proto:
|
protocols = arguments.proto.split(",")
|
||||||
protocols = arguments.proto.split(",")
|
# if "," not in arguments.proto, [arguments.proto] is returned by
|
||||||
else:
|
# .split()
|
||||||
protocols = arguments.proto
|
else:
|
||||||
commands = InputHelper._replace_variable_with_commands(commands, "_proto_", protocols)
|
protocols = None
|
||||||
|
|
||||||
# process proxies
|
# Calculate the tasks count, as we will not have access to the len() of
|
||||||
if arguments.proxy_list:
|
# the tasks iterator
|
||||||
proxy_list = [proxy for proxy in arguments.proxy_list if proxy.strip()]
|
tasks_count = len(tasks) * targets_count
|
||||||
if len(proxy_list) < len(commands):
|
if ports:
|
||||||
proxy_list = ceil(len(commands) / len(proxy_list)) * proxy_list
|
tasks_count *= len(ports)
|
||||||
|
if real_ports:
|
||||||
|
tasks_count *= len(real_ports)
|
||||||
|
if protocols:
|
||||||
|
tasks_count *= len(protocols)
|
||||||
|
|
||||||
InputHelper._replace_variable_array(commands, "_proxy_", proxy_list)
|
return {
|
||||||
return commands
|
"tasks": tasks,
|
||||||
|
"str_targets": str_targets,
|
||||||
|
"ipset_targets": ipset_targets,
|
||||||
|
"ports": ports,
|
||||||
|
"real_ports": real_ports,
|
||||||
|
"random_file": random_file,
|
||||||
|
"output": arguments.output,
|
||||||
|
"protocols": protocols,
|
||||||
|
"proxy_list": arguments.proxy_list,
|
||||||
|
"tasks_count": tasks_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_tasks_generator_func(tasks_data):
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_target_variables_in_commands,
|
||||||
|
tasks=tasks_data["tasks"],
|
||||||
|
str_targets=tasks_data["str_targets"],
|
||||||
|
ipset_targets=tasks_data["ipset_targets"],
|
||||||
|
)
|
||||||
|
|
||||||
|
ports = tasks_data["ports"]
|
||||||
|
if ports:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_port_",
|
||||||
|
replacements=ports,
|
||||||
|
)
|
||||||
|
|
||||||
|
real_ports = tasks_data["real_ports"]
|
||||||
|
if real_ports:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_realport_",
|
||||||
|
replacements=real_ports,
|
||||||
|
)
|
||||||
|
|
||||||
|
random_file = tasks_data["random_file"]
|
||||||
|
if random_file:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_random_",
|
||||||
|
replacements=[random_file],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = tasks_data["output"]
|
||||||
|
if output:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_output_",
|
||||||
|
replacements=[output],
|
||||||
|
)
|
||||||
|
|
||||||
|
protocols = tasks_data["protocols"]
|
||||||
|
if protocols:
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_in_commands,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_proto_",
|
||||||
|
replacements=protocols,
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_list = tasks_data["proxy_list"]
|
||||||
|
if proxy_list:
|
||||||
|
proxy_list_iterator = itertools.cycle(
|
||||||
|
proxy for proxy in (
|
||||||
|
proxy.strip() for proxy in proxy_list
|
||||||
|
) if proxy
|
||||||
|
)
|
||||||
|
tasks_generator_func = functools.partial(
|
||||||
|
InputHelper._replace_variable_array,
|
||||||
|
tasks_generator_func=tasks_generator_func,
|
||||||
|
variable="_proxy_",
|
||||||
|
replacements_iterator=proxy_list_iterator,
|
||||||
|
)
|
||||||
|
|
||||||
|
return tasks_generator_func
|
||||||
|
|
||||||
|
|
||||||
class InputParser(object):
|
class InputParser(object):
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class Worker(object):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# get task from queue
|
# get task from queue
|
||||||
task = self.queue.pop(0)
|
task = next(self.queue)
|
||||||
if isinstance(self.tqdm, tqdm):
|
if isinstance(self.tqdm, tqdm):
|
||||||
self.tqdm.update(1)
|
self.tqdm.update(1)
|
||||||
# run task
|
# run task
|
||||||
@@ -90,17 +90,19 @@ class Pool(object):
|
|||||||
if max_workers <= 0:
|
if max_workers <= 0:
|
||||||
raise ValueError("Workers must be >= 1")
|
raise ValueError("Workers must be >= 1")
|
||||||
|
|
||||||
|
tasks_count = next(task_queue)
|
||||||
|
|
||||||
# check if the queue is empty
|
# check if the queue is empty
|
||||||
if not task_queue:
|
if not tasks_count:
|
||||||
raise ValueError("The queue is empty")
|
raise ValueError("The queue is empty")
|
||||||
|
|
||||||
self.queue = task_queue
|
self.queue = task_queue
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.output = output
|
self.output = output
|
||||||
self.max_workers = min(len(task_queue), max_workers)
|
self.max_workers = min(tasks_count, max_workers)
|
||||||
|
|
||||||
if not progress_bar:
|
if not progress_bar:
|
||||||
self.tqdm = tqdm(total=len(task_queue))
|
self.tqdm = tqdm(total=tasks_count)
|
||||||
else:
|
else:
|
||||||
self.tqdm = True
|
self.tqdm = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user