# Mirror of https://github.com/aljazceru/azure-firewall-updater.git
# (synced 2025-12-16 20:25:08 +01:00; 185 lines, 7.5 KiB, Python)
import os
|
|
import json
|
|
from datetime import datetime
|
|
from azure.identity import AzureCliCredential
|
|
from azure.mgmt.network import NetworkManagementClient
|
|
from azure.mgmt.compute import ComputeManagementClient
|
|
from azure.mgmt.subscription import SubscriptionClient
|
|
import requests
|
|
|
|
def get_azure_clients():
    """Create and return Azure clients using Azure CLI credentials.

    The first subscription listed for the CLI login is used for both
    clients.

    Returns:
        A ``(network_client, compute_client)`` tuple.

    Raises:
        SystemExit: with exit code 1 when the credential, the subscription
            lookup or the client construction fails for any reason. An
            explanatory message is printed first.
    """
    try:
        credential = AzureCliCredential()

        # Get the subscription ID. Only the first entry is used, so stop
        # iterating after one item instead of materialising the whole list.
        subscription_client = SubscriptionClient(credential)
        first_subscription = next(iter(subscription_client.subscriptions.list()), None)
        if first_subscription is None:
            raise ValueError("No subscriptions found. Please check your Azure CLI login.")
        subscription_id = first_subscription.subscription_id

        network_client = NetworkManagementClient(credential, subscription_id)
        compute_client = ComputeManagementClient(credential, subscription_id)
        return network_client, compute_client
    except Exception as e:
        # Top-level boundary: report the problem and terminate the script.
        print(f"Error setting up Azure clients: {e}")
        print("Please ensure you're logged in with Azure CLI (run 'az login')")
        # The bare `exit()` helper is injected by the `site` module for
        # interactive sessions and is missing under `python -S`; raising
        # SystemExit is the equivalent that always exists.
        raise SystemExit(1) from e
|
|
|
|
def get_resource_group_from_id(resource_id):
    """Extract the resource group name from an Azure resource ID.

    The ``resourceGroups`` path segment is matched case-insensitively, so
    IDs spelled with ``resourcegroups`` or ``RESOURCEGROUPS`` are accepted
    as well.

    Args:
        resource_id: A ``/``-separated resource ID string.

    Returns:
        The path segment that follows the ``resourceGroups`` segment,
        exactly as it is spelled in ``resource_id``.

    Raises:
        ValueError: if no ``resourceGroups`` segment followed by a name is
            present.
    """
    parts = resource_id.split('/')
    # Skip the last segment: a trailing "resourceGroups" has no name after it.
    for index, segment in enumerate(parts[:-1]):
        if segment.lower() == 'resourcegroups':
            return parts[index + 1]
    raise ValueError(f"No resource group found in resource ID: {resource_id!r}")
|
|
|
|
def get_current_ip(timeout=10):
    """Get the current outgoing IP address.

    The address is read from the body of ``https://api.ipify.org`` and is
    validated before being returned, because callers place it in a
    firewall rule.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The IP address as a string with surrounding whitespace removed.

    Raises:
        requests.RequestException: if the request fails, times out or
            returns an HTTP error status.
        ValueError: if the response body is not a valid IP address.
    """
    # Local import: only this function needs it, and the file's import
    # block is left untouched.
    import ipaddress

    # Without a timeout `requests` may wait indefinitely on a stalled
    # connection.
    response = requests.get('https://api.ipify.org', timeout=timeout)
    # An error page must not be mistaken for an address.
    response.raise_for_status()
    address = response.text.strip()
    ipaddress.ip_address(address)  # raises ValueError on anything that is not an IP
    return address
|
|
|
|
def backup_nsg(nsg, backup_dir):
    """Create a JSON backup of the network security group.

    Args:
        nsg: Object exposing ``name`` and ``as_dict()``; the result of
            ``as_dict()`` is what gets written.
        backup_dir: Directory for the backup file. It is created when
            missing; an empty string means the current directory.

    Returns:
        Path of the written file, named ``<nsg.name>_<YYYYmmddHHMMSS>.json``.
    """
    if backup_dir:
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(backup_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    backup_file = os.path.join(backup_dir, f"{nsg.name}_{timestamp}.json")
    # Explicit encoding keeps the file independent of the platform default;
    # indentation keeps the backup readable when it has to be inspected.
    with open(backup_file, 'w', encoding='utf-8') as f:
        json.dump(nsg.as_dict(), f, indent=2)
    return backup_file
|
|
|
|
def update_nsg_rule(network_client, nsg, ip_address, dry_run=False):
    """Update the NSG so that SSH (port 22) is allowed from ``ip_address``.

    An existing rule is reused only when it is an inbound *allow* rule for
    TCP (or any protocol) whose ``destination_port_range`` is ``'22'``;
    deny rules and outbound rules on port 22 are never modified. When no
    such rule exists a new ``AllowSSH`` rule is appended. ``nsg`` is
    modified in place even for a dry run; only the service call is skipped.

    Args:
        network_client: Client exposing
            ``network_security_groups.begin_create_or_update``.
        nsg: The network security group object to modify.
        ip_address: Source address to allow.
        dry_run: When true, do not send the change to Azure.

    Returns:
        The object returned by ``begin_create_or_update`` or ``None`` for a
        dry run.
    """

    def _is_inbound_ssh_allow(rule):
        # Values are lower-cased before comparing so that differently cased
        # strings (or str-based enum members) still match.
        direction = (getattr(rule, 'direction', None) or '').lower()
        access = (getattr(rule, 'access', None) or '').lower()
        protocol = (getattr(rule, 'protocol', None) or '').lower()
        return (
            getattr(rule, 'destination_port_range', None) == '22'
            and direction == 'inbound'
            and access == 'allow'
            and protocol in ('tcp', '*')
        )

    # A group without custom rules may expose None instead of a list.
    if nsg.security_rules is None:
        nsg.security_rules = []

    ssh_rule = next((rule for rule in nsg.security_rules if _is_inbound_ssh_allow(rule)), None)

    if ssh_rule is not None:
        single_prefix = ssh_rule.source_address_prefix
        # A '*' source already admits every address; leave such a rule alone.
        if single_prefix != '*':
            prefixes = list(ssh_rule.source_address_prefixes or [])
            if single_prefix:
                # Fold the single-prefix field into the list so the rule
                # does not carry both forms at once.
                # NOTE(review): the service is expected to reject a rule
                # with both fields set - confirm against the API reference.
                if single_prefix not in prefixes:
                    prefixes.insert(0, single_prefix)
                ssh_rule.source_address_prefix = None
            if ip_address not in prefixes:
                prefixes.append(ip_address)
            ssh_rule.source_address_prefixes = prefixes
    else:
        # Start at 1000 and move to the next number not used by another
        # inbound rule, so the new rule does not collide on priority.
        used_priorities = {
            getattr(rule, 'priority', None)
            for rule in nsg.security_rules
            if (getattr(rule, 'direction', None) or '').lower() == 'inbound'
        }
        priority = 1000
        while priority in used_priorities:
            priority += 1

        ssh_rule = {
            'name': 'AllowSSH',
            'protocol': 'Tcp',
            'source_port_range': '*',
            'destination_port_range': '22',
            'source_address_prefixes': [ip_address],
            'destination_address_prefix': '*',
            'access': 'Allow',
            'priority': priority,
            'direction': 'Inbound',
        }
        nsg.security_rules.append(ssh_rule)

    if dry_run:
        return None

    return network_client.network_security_groups.begin_create_or_update(
        get_resource_group_from_id(nsg.id), nsg.name, nsg
    )
|
|
|
|
def log_ip(ip_address, log_file):
    """Append one ``<ISO timestamp>,<ip>`` line to ``log_file``."""
    record = "{},{}\n".format(datetime.now().isoformat(), ip_address)
    with open(log_file, mode='a') as handle:
        handle.write(record)
|
|
|
|
def list_vms():
    """Print the name and resource group of every VM in the subscription."""
    # Only the compute client (second element) is needed here.
    compute_client = get_azure_clients()[1]
    machines = compute_client.virtual_machines.list_all()

    print("Available Virtual Machines:")
    for machine in machines:
        group = get_resource_group_from_id(machine.id)
        print(f"- {machine.name} (Resource Group: {group})")
|
|
|
|
def firewall_dump(vm_name=None):
    """Iterate through virtual machines and print their firewall rules.

    For every NIC of every selected VM, the NSG attached to the NIC is
    fetched and each of its security rules is printed. NICs and NSGs are
    looked up in the resource group named in their own resource ID, since
    they can live in a different group than the VM.

    Args:
        vm_name: When given, only VMs with exactly this name are dumped;
            otherwise all VMs in the subscription are.
    """
    network_client, compute_client = get_azure_clients()

    def _group_of(resource_id, fallback):
        # Prefer the group encoded in the resource's own ID; keep the old
        # behaviour (the VM's group) when the ID cannot be parsed.
        try:
            return get_resource_group_from_id(resource_id)
        except (ValueError, IndexError):
            return fallback

    if vm_name:
        vms = [vm for vm in compute_client.virtual_machines.list_all() if vm.name == vm_name]
        if not vms:
            print(f"No VM found with name: {vm_name}")
            return
    else:
        vms = compute_client.virtual_machines.list_all()

    for vm in vms:
        print(f"\nFirewall rules for VM: {vm.name}")
        print("=" * 50)

        resource_group_name = get_resource_group_from_id(vm.id)

        for nic_ref in vm.network_profile.network_interfaces:
            nic_name = nic_ref.id.split('/')[-1]
            try:
                nic = network_client.network_interfaces.get(
                    _group_of(nic_ref.id, resource_group_name), nic_name
                )

                if not nic.network_security_group:
                    print(f"No Network Security Group associated with NIC: {nic_name}")
                    continue

                nsg_id = nic.network_security_group.id
                nsg_name = nsg_id.split('/')[-1]
                nsg = network_client.network_security_groups.get(
                    _group_of(nsg_id, resource_group_name), nsg_name
                )

                # security_rules may be None for a group without custom rules.
                for rule in nsg.security_rules or []:
                    print(f"Rule: {rule.name}")
                    print(f" Direction: {rule.direction}")
                    print(f" Priority: {rule.priority}")
                    print(f" Protocol: {rule.protocol}")
                    print(f" Source Port Range: {rule.source_port_range}")
                    print(f" Destination Port Range: {rule.destination_port_range}")
                    print(f" Source Address Prefix: {rule.source_address_prefix}")
                    # The list form is the field update_nsg_rule writes to,
                    # so it has to be shown for the dump to be useful.
                    if rule.source_address_prefixes:
                        print(f" Source Address Prefixes: {', '.join(rule.source_address_prefixes)}")
                    print(f" Destination Address Prefix: {rule.destination_address_prefix}")
                    print(f" Access: {rule.access}")
                    print("-" * 40)
            except Exception as e:
                # Best effort: report the failing NIC and carry on with the rest.
                print(f"Error processing NIC {nic_name}: {str(e)}")
|
|
|
|
def main(dry_run=False, vm_name=None):
    """Allow SSH from the current public IP on the NSGs of the selected VMs.

    The current IP is looked up and appended to ``ip_log.csv``. Then, for
    every NIC of every selected VM that has an NSG attached, the NSG is
    backed up into ``nsg_backups`` and updated through ``update_nsg_rule``.
    NICs and NSGs are looked up in the resource group named in their own
    resource ID, and an NSG shared by several NICs is handled once per run.

    Args:
        dry_run: When true, backups are still written but no change is
            sent to Azure.
        vm_name: When given, only VMs with exactly this name are processed;
            otherwise all VMs in the subscription are.
    """
    network_client, compute_client = get_azure_clients()
    current_ip = get_current_ip()
    backup_dir = 'nsg_backups'
    log_file = 'ip_log.csv'

    log_ip(current_ip, log_file)

    def _group_of(resource_id, fallback):
        # Prefer the group encoded in the resource's own ID; keep the old
        # behaviour (the VM's group) when the ID cannot be parsed.
        try:
            return get_resource_group_from_id(resource_id)
        except (ValueError, IndexError):
            return fallback

    if vm_name:
        vms = [vm for vm in compute_client.virtual_machines.list_all() if vm.name == vm_name]
        if not vms:
            print(f"No VM found with name: {vm_name}")
            return
    else:
        vms = compute_client.virtual_machines.list_all()

    # IDs (lower-cased) of NSGs already handled, so a group shared by
    # several NICs or VMs is not backed up and updated more than once.
    handled_nsg_ids = set()

    for vm in vms:
        resource_group_name = get_resource_group_from_id(vm.id)
        for nic_ref in vm.network_profile.network_interfaces:
            nic_name = nic_ref.id.split('/')[-1]
            nic = network_client.network_interfaces.get(
                _group_of(nic_ref.id, resource_group_name), nic_name
            )

            if not nic.network_security_group:
                print(f"No Network Security Group associated with NIC: {nic_name}")
                continue

            nsg_id = nic.network_security_group.id
            if nsg_id.lower() in handled_nsg_ids:
                continue
            handled_nsg_ids.add(nsg_id.lower())

            nsg_name = nsg_id.split('/')[-1]
            nsg = network_client.network_security_groups.get(
                _group_of(nsg_id, resource_group_name), nsg_name
            )

            backup_file = backup_nsg(nsg, backup_dir)
            print(f"Backed up NSG {nsg.name} to {backup_file}")

            operation = update_nsg_rule(network_client, nsg, current_ip, dry_run)
            if dry_run:
                print(f"Dry run: Would update NSG {nsg.name} to allow SSH from {current_ip}")
            else:
                # Block until the service has finished applying the update.
                operation.wait()
                print(f"Updated NSG {nsg.name} to allow SSH from {current_ip}")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: --list and --dump select the read-only
    # modes; without them the firewall update itself is run.
    cli = argparse.ArgumentParser(description="Update Azure VM firewalls for SSH access.")
    for flag, help_text in (
        ("--dry-run", "Perform a dry run without making changes."),
        ("--dump", "Dump firewall rules for all VMs."),
        ("--list", "List all available VMs."),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)
    cli.add_argument("--vm", help="Specify a VM name to update or dump firewall rules for.")
    options = cli.parse_args()

    if options.list:
        list_vms()
    elif options.dump:
        firewall_dump(options.vm)
    else:
        main(dry_run=options.dry_run, vm_name=options.vm)