Skip to content
Snippets Groups Projects
Commit b0f5bb46 authored by aquadzn's avatar aquadzn Committed by Mathilde Caron
Browse files

Refactoring. Much cleaner

parent a8e94db1
Branches
No related tags found
No related merge requests found
......@@ -154,25 +154,26 @@ https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-
Extract frames from input video and generate attention video:
```
python video_generation.py --input_path ../video.mp4 \
--output_dir ../output/ \
--resize 256 \
python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
--input_path input/video.mp4 \
--output_path output/ \
--fps 25
```
Use folder of frames already extracted and attention video:
Use folder of frames already extracted and generate attention video:
```
python video_generation.py --input_path ../frames/ \
--output_dir ../output/ \
--resize 720 1280 \
--video_format avi
python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
--input_path output/frames/ \
--output_path output/ \
--resize 256 \
```
Only generate video from folder of attention maps images:
```
python video_generation.py --output_dir ../output/ \
--resize 256 \
--fps 60 \
--video_only
python video_generation.py --input_path output/attention \
--output_path output/ \
--video_only \
--video_format avi
```
Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook.
......
......@@ -18,38 +18,116 @@ import utils
import vision_transformer as vits
def extract_frames_from_video():
vidcap = cv2.VideoCapture(args.input_path)
args.fps = vidcap.get(cv2.CAP_PROP_FPS)
print(f"Video: {args.input_path} ({args.fps} fps)")
print("Extracting frames...")
FOURCC = {
"mp4": cv2.VideoWriter_fourcc(*"MP4V"),
"avi": cv2.VideoWriter_fourcc(*"XVID"),
}
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VideoGenerator:
def __init__(self, args):
self.args = args
# self.model = None
# Don't need to load model if you only want a video
if not self.args.video_only:
self.model = self.__load_model()
def run(self):
if self.args.input_path is None:
print(f"Provided input path {self.args.input_path} is non valid.")
sys.exit(1)
else:
if self.args.video_only:
self._generate_video_from_images(
self.args.input_path, self.args.output_path
)
else:
# If input path exists
if os.path.exists(self.args.input_path):
# If input is a video file
if os.path.isfile(self.args.input_path):
frames_folder = os.path.join(self.args.output_path, "frames")
attention_folder = os.path.join(
self.args.output_path, "attention"
)
os.makedirs(frames_folder, exist_ok=True)
os.makedirs(attention_folder, exist_ok=True)
self._extract_frames_from_video(
self.args.input_path, frames_folder
)
self._inference(
frames_folder,
attention_folder,
)
self._generate_video_from_images(
attention_folder, self.args.output_path
)
# If input is a folder of already extracted frames
if os.path.isdir(self.args.input_path):
attention_folder = os.path.join(
self.args.output_path, "attention"
)
os.makedirs(attention_folder, exist_ok=True)
self._inference(self.args.input_path, attention_folder)
self._generate_video_from_images(
attention_folder, self.args.output_path
)
# If input path doesn't exists
else:
print(f"Provided input path {self.args.input_path} doesn't exists.")
sys.exit(1)
def _extract_frames_from_video(self, inp: str, out: str):
vidcap = cv2.VideoCapture(inp)
self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
print(f"Video: {inp} ({self.args.fps} fps)")
print(f"Extracting frames to {out}")
success, image = vidcap.read()
count = 0
while success:
cv2.imwrite(os.path.join(args.output_dir, f"frame-{count:04}.jpg"), image)
cv2.imwrite(
os.path.join(out, f"frame-{count:04}.jpg"),
image,
)
success, image = vidcap.read()
count += 1
def generate_video_from_images(format="mp4"):
print("Generating video...")
def _generate_video_from_images(self, inp: str, out: str):
img_array = []
# Change format to png if needed
for filename in tqdm(sorted(glob.glob(os.path.join(args.output_dir, "attn-*.jpg")))):
with open(filename, "rb") as f:
attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
# Get size of the first image
with open(attention_images_list[0], "rb") as f:
img = Image.open(f)
img = img.convert("RGB")
size = (img.width, img.height)
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
print(f"Generating video {size} to {out}")
for filename in tqdm(attention_images_list[1:]):
with open(filename, "rb") as f:
img = Image.open(f)
img = img.convert("RGB")
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
if args.video_format == "avi":
out = cv2.VideoWriter(
"video.avi", cv2.VideoWriter_fourcc(*"XVID"), args.fps, size
)
else:
out = cv2.VideoWriter(
"video.mp4", cv2.VideoWriter_fourcc(*"MP4V"), args.fps, size
os.path.join(out, "video." + self.args.video_format),
FOURCC[self.args.video_format],
self.args.fps,
size,
)
for i in range(len(img_array)):
......@@ -57,18 +135,19 @@ def generate_video_from_images(format="mp4"):
out.release()
print("Done")
def _inference(self, inp: str, out: str):
print(f"Generating attention images to {out}")
def inference(images_folder_list: str):
for img_path in tqdm(images_folder_list):
for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
with open(img_path, "rb") as f:
img = Image.open(f)
img = img.convert("RGB")
if args.resize is not None:
if self.args.resize is not None:
transform = pth_transforms.Compose(
[
pth_transforms.ToTensor(),
pth_transforms.Resize(args.resize),
pth_transforms.Resize(self.args.resize),
pth_transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
),
......@@ -88,15 +167,15 @@ def inference(images_folder_list: str):
# make the image divisible by the patch size
w, h = (
img.shape[1] - img.shape[1] % args.patch_size,
img.shape[2] - img.shape[2] % args.patch_size,
img.shape[1] - img.shape[1] % self.args.patch_size,
img.shape[2] - img.shape[2] % self.args.patch_size,
)
img = img[:, :w, :h].unsqueeze(0)
w_featmap = img.shape[-2] // args.patch_size
h_featmap = img.shape[-1] // args.patch_size
w_featmap = img.shape[-2] // self.args.patch_size
h_featmap = img.shape[-1] // self.args.patch_size
attentions = model.forward_selfattention(img.to(device))
attentions = self.model.forward_selfattention(img.to(DEVICE))
nh = attentions.shape[1] # number of head
......@@ -107,7 +186,7 @@ def inference(images_folder_list: str):
val, idx = torch.sort(attentions)
val /= torch.sum(val, dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - args.threshold)
th_attn = cumval > (1 - self.args.threshold)
idx2 = torch.argsort(idx)
for head in range(nh):
th_attn[head] = th_attn[head][idx2[head]]
......@@ -115,7 +194,9 @@ def inference(images_folder_list: str):
# interpolate
th_attn = (
nn.functional.interpolate(
th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
th_attn.unsqueeze(0),
scale_factor=self.args.patch_size,
mode="nearest",
)[0]
.cpu()
.numpy()
......@@ -124,15 +205,16 @@ def inference(images_folder_list: str):
attentions = attentions.reshape(nh, w_featmap, h_featmap)
attentions = (
nn.functional.interpolate(
attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
attentions.unsqueeze(0),
scale_factor=self.args.patch_size,
mode="nearest",
)[0]
.cpu()
.numpy()
)
# save attentions heatmaps
os.makedirs(args.output_dir, exist_ok=True)
fname = os.path.join(args.output_dir, "attn-" + os.path.basename(img_path))
fname = os.path.join(out, "attn-" + os.path.basename(img_path))
plt.imsave(
fname=fname,
arr=sum(
......@@ -143,26 +225,31 @@ def inference(images_folder_list: str):
format="jpg",
)
generate_video_from_images(args.video_format)
def load_model():
def __load_model(self):
# build model
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
model = vits.__dict__[self.args.arch](
patch_size=self.args.patch_size, num_classes=0
)
for p in model.parameters():
p.requires_grad = False
model.eval()
model.to(device)
if os.path.isfile(args.pretrained_weights):
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
state_dict = state_dict[args.checkpoint_key]
model.to(DEVICE)
if os.path.isfile(self.args.pretrained_weights):
state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
if (
self.args.checkpoint_key is not None
and self.args.checkpoint_key in state_dict
):
print(
f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
)
state_dict = state_dict[self.args.checkpoint_key]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
msg = model.load_state_dict(state_dict, strict=False)
print(
"Pretrained weights found at {} and loaded with msg: {}".format(
args.pretrained_weights, msg
self.args.pretrained_weights, msg
)
)
else:
......@@ -170,13 +257,13 @@ def load_model():
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
)
url = None
if args.arch == "deit_small" and args.patch_size == 16:
if self.args.arch == "deit_small" and self.args.patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif args.arch == "deit_small" and args.patch_size == 8:
elif self.args.arch == "deit_small" and self.args.patch_size == 8:
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
elif args.arch == "vit_base" and args.patch_size == 16:
elif self.args.arch == "vit_base" and self.args.patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
elif args.arch == "vit_base" and args.patch_size == 8:
elif self.args.arch == "vit_base" and self.args.patch_size == 8:
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
if url is not None:
print(
......@@ -194,7 +281,7 @@ def load_model():
def parse_args():
parser = argparse.ArgumentParser("Visualize Self-Attention maps")
parser = argparse.ArgumentParser("Generation self-attention video")
parser.add_argument(
"--arch",
default="deit_small",
......@@ -203,7 +290,7 @@ def parse_args():
help="Architecture (support only ViT atm).",
)
parser.add_argument(
"--patch_size", default=8, type=int, help="Patch resolution of the model."
"--patch_size", default=8, type=int, help="Patch resolution of the self.model."
)
parser.add_argument(
"--pretrained_weights",
......@@ -219,16 +306,18 @@ def parse_args():
)
parser.add_argument(
"--input_path",
default=None,
required=True,
type=str,
help="""Path to a video file if you want to extract frames
or to a folder of images already extracted by yourself.""",
or to a folder of images already extracted by yourself.
or to a folder of attention images.""",
)
parser.add_argument(
"--output_dir",
required=True,
"--output_path",
default="./",
type=str,
help="Path where to save visualizations and / or video.",
help="""Path to store a folder of frames and / or a folder of attention images.
and / or a final video. Default to current directory.""",
)
parser.add_argument(
"--threshold",
......@@ -245,18 +334,18 @@ def parse_args():
help="""Apply a resize transformation to input image(s). Use if OOM error.
Usage (single or W H): --resize 512, --resize 720 1280""",
)
parser.add_argument(
"--video_only",
action="store_true",
help="""Use this flag if you only want to generate a video and not all attention images.
If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
)
parser.add_argument(
"--fps",
default=30.0,
type=float,
help="FPS of input / output video. Automatically set if you extract frames from a video.",
)
parser.add_argument(
"--video_only",
action="store_true",
help="""Use this flag if you only want to generate a video and not all attention images.
If used, --output_dir must be set to the folder containing attention images.""",
)
parser.add_argument(
"--video_format",
default="mp4",
......@@ -270,36 +359,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = load_model()
# If you only want a video
if args.video_only:
generate_video_from_images(args.video_format)
else:
# If input path isn't set
if args.input_path is None:
print(f"Provided input path {args.input_path} is non valid.")
sys.exit(1)
else:
# If input path exists
if os.path.exists(args.input_path):
# If input is a video file
if os.path.isfile(args.input_path):
extract_frames_from_video()
imgs_list = [
os.path.join(args.output_dir, i)
for i in sorted(os.listdir(args.output_dir))
]
inference(imgs_list)
# If input is an images folder
if os.path.isdir(args.input_path):
imgs_list = [
os.path.join(args.input_path, i)
for i in sorted(os.listdir(args.input_path))
]
inference(imgs_list)
# If input path doesn't exists
else:
print(f"Provided video file path {args.input_path} is non valid.")
sys.exit(1)
vg = VideoGenerator(args)
vg.run()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment