diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6daf17e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: Continuous integration + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + tests: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install + run: | + python3 -m venv .env + source .env/bin/activate + make install + - name: Tests + run: | + source .env/bin/activate + make test + diff --git a/.gitignore b/.gitignore index 41f11cf..1e41427 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,5 @@ dmypy.json # Pyre type checker .pyre/ +.tracker_data +*.pth diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5ce5220 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +install: + pip install -U pip + pip install -e . + +test: + CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json diff --git a/configs/train_decoder_config.test.json b/configs/train_decoder_config.test.json new file mode 100644 index 0000000..d9aa957 --- /dev/null +++ b/configs/train_decoder_config.test.json @@ -0,0 +1,102 @@ +{ + "decoder": { + "unets": [ + { + "dim": 16, + "image_embed_dim": 768, + "cond_dim": 16, + "channels": 3, + "dim_mults": [1, 2, 4, 8], + "attn_dim_head": 16, + "attn_heads": 4, + "self_attn": [false, true, true, true] + } + ], + "clip": { + "make": "openai", + "model": "ViT-L/14" + }, + + "timesteps": 10, + "image_sizes": [64], + "channels": 3, + "loss_type": "l2", + "beta_schedule": ["cosine"], + "learned_variance": true + }, + "data": { + "webdataset_base_url": "test_data/{}.tar", + "num_workers": 4, + "batch_size": 4, + "start_shard": 0, + "end_shard": 9, + "shard_width": 1, + "index_width": 1, + "splits": { + "train": 0.75, + "val": 0.15, + "test": 0.1 + }, + "shuffle_train": false, + "resample_train": true, + "preprocessing": { + "RandomResizedCrop": { + "size": [64, 64], + "scale": [0.75, 1.0], + "ratio": [1.0, 1.0] + }, + "ToTensor": true + } + }, + "train": { + "epochs": 1, + "lr": 1e-16, + "wd": 0.01, + "max_grad_norm": 0.5, + "save_every_n_samples": 100, + "n_sample_images": 1, + "device": "cpu", + "epoch_samples": 50, + "validation_samples": 5, + "use_ema": true, + "ema_beta": 0.99, + "amp": false, + "save_all": false, + "save_latest": true, + "save_best": true, + "unet_training_mask": [true] + }, + "evaluate": { + "n_evaluation_samples": 2, + "FID": { + "feature": 64 + }, + "IS": { + "feature": 64, + "splits": 10 + }, + "KID": { + "feature": 64, + "subset_size": 2 + }, + "LPIPS": { + "net_type": "vgg", + "reduction": "mean" + } + }, + "tracker": { + "overwrite_data_path": true, + + "log": { + "log_type": "console" + }, + + "load": { + "load_from": null + }, + + "save": [{ + "save_to": "local" + }] + } +} diff --git a/test_data/0.tar b/test_data/0.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/0.tar differ diff --git a/test_data/1.tar b/test_data/1.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/1.tar differ diff --git a/test_data/2.tar b/test_data/2.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/2.tar differ diff --git a/test_data/3.tar b/test_data/3.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/3.tar differ diff --git a/test_data/4.tar b/test_data/4.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/4.tar differ diff --git a/test_data/5.tar b/test_data/5.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/5.tar differ diff --git a/test_data/6.tar b/test_data/6.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/6.tar differ diff --git a/test_data/7.tar b/test_data/7.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/7.tar differ diff --git a/test_data/8.tar b/test_data/8.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/8.tar differ diff --git a/test_data/9.tar b/test_data/9.tar new file mode 100644 index 0000000..91ad1ea Binary files /dev/null and b/test_data/9.tar differ