From e5d154578c67891f6f6b4e457859774ad45c0848 Mon Sep 17 00:00:00 2001 From: Markus Kreitzer Date: Tue, 16 Jan 2024 16:54:03 -0600 Subject: [PATCH] Some improvements to the interface. --- poetry.lock | 2 +- pyproject.toml | 1 + src/video_upscaler/main.py | 62 ++++++++++++++++++++++++++---------- tests/data/DSC_0141.jpeg | Bin 0 -> 6574 bytes tests/test_image_upscale.py | 3 ++ 5 files changed, 50 insertions(+), 18 deletions(-) create mode 100644 tests/data/DSC_0141.jpeg create mode 100644 tests/test_image_upscale.py diff --git a/poetry.lock b/poetry.lock index 16815cc..ca2b26a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1134,4 +1134,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fc966464d3ffb01227a2e78352820e32055828c4b685e01e199bb663b4f0a1ba" +content-hash = "9506eb71838bf8c0d0acfe2bda9c32ee7706ef504e1c6f020537754840f04177" diff --git a/pyproject.toml b/pyproject.toml index 9ecdef2..d162212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ torch = "^2.1.2" realesrgan = {git = "https://github.com/sberbank-ai/Real-ESRGAN.git"} torchvision = "^0.16.2" torchaudio = "^2.1.2" +tqdm = "^4.66.1" [tool.poetry.scripts] upscale = "video_upscaler.main:main" diff --git a/src/video_upscaler/main.py b/src/video_upscaler/main.py index 89895d9..7e62b5a 100644 --- a/src/video_upscaler/main.py +++ b/src/video_upscaler/main.py @@ -6,6 +6,13 @@ from RealESRGAN import RealESRGAN from argparse import ArgumentParser from pathlib import Path +from tqdm import tqdm + +def upscale_image(image: Image, device: torch.device, weights: Path, scale: int = 4) -> Image: + """Upscale a single image.""" + upscaler = get_upscaler(device, scale) + upscaler.load_weights(weights, download=True) + return upscaler.predict(image) def upscale_frame(frame: av.VideoFrame, upscaler: RealESRGAN) -> av.VideoFrame: """Upscale a single frame.""" @@ -13,22 +20,17 @@ def upscale_frame(frame: av.VideoFrame, upscaler: RealESRGAN) -> av.VideoFrame: sr_img = upscaler.predict(img) return av.VideoFrame.from_image(sr_img) -def upscale_video(input_path: Path, output_path: Path, scale: int = 4): - """Upscale video frames and reassemble the video.""" - if torch.cuda.is_available(): - device = torch.device('cuda') - print("Using CUDA.") - elif torch.backends.mps.is_available(): - device = torch.device('mps') - print("Using MPS.") - else: - device = torch.device('cpu') - print("Using CPU.") +def get_upscaler(device: torch.device, model_weights_path: Path, scale: int = 4) -> RealESRGAN: + """Get the upscaler model.""" upscaler = RealESRGAN(device, scale) - upscaler.load_weights('weights/RealESRGAN_x4.pth', download=True) + upscaler.load_weights(str(model_weights_path), download=True) + return upscaler + +def upscale_video(input_path: Path, upscaler: RealESRGAN, scale: int = 4) -> None: + """Upscale video frames and reassemble the video.""" input_container = av.open(str(input_path)) - output_container = av.open(str(output_path), 'w') + output_container = av.open(f"{input_path.stem}_HD.mp4", 'w') stream = input_container.streams.video[0] output_stream = output_container.add_stream('mpeg4', rate=stream.average_rate) @@ -36,25 +38,51 @@ def upscale_video(input_path: Path, output_path: Path, scale: int = 4): output_stream.height = stream.height * scale output_stream.pix_fmt = 'yuv420p' + total_frames = input_container.streams.video[0].frames + + print("Upscaling video...") + progress_bar = tqdm(total=total_frames, unit='frames', desc='Upscaling') + for frame in input_container.decode(stream): sr_frame = upscale_frame(frame, upscaler) packet = output_stream.encode(sr_frame) output_container.mux(packet) + progress_bar.update() + + progress_bar.close() - # Flush and close the containers + print("Flush and close the containers") + progress_bar = tqdm(total=total_frames, unit='frames', desc='Encoding/Muxing') for packet in output_stream.encode(): output_container.mux(packet) + progress_bar.update() + + progress_bar.close() input_container.close() output_container.close() +def get_device() -> torch.device: + if torch.cuda.is_available(): + print("Using CUDA.") + return torch.device('cuda') + elif torch.backends.mps.is_available(): + print("Using MPS.") + return torch.device('mps') + else: + print("Using CPU.") + return torch.device('cpu') + def main(): parser = ArgumentParser(description="Upscale video files.") parser.add_argument('input_file', type=Path, help='Path to the input video file.') - parser.add_argument('output_file', type=Path, help='Path for the output upscaled video file.') + parser.add_argument('--scale', type=int, default=4, help='Upscaling factor.') + parser.add_argument('--model_weights', type=Path, default=Path('weights/RealESRGAN_x4.pth'), help='Path to the model weights file.') + parser.add_argument('--download-weights', action='store_true', default=True, help='Download the model weights if they are not found locally.') args = parser.parse_args() - - upscale_video(args.input_file, args.output_file) + device = get_device() + upscaler = get_upscaler(device, args.model_weights, args.scale) + upscale_video(args.input_file, upscaler, args.scale) if __name__ == "__main__": main() diff --git a/tests/data/DSC_0141.jpeg b/tests/data/DSC_0141.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..8dfdac4608e84610259b7078b054c3f5e6a33341 GIT binary patch literal 6574 zcmeHKcT`i^_C7a-02T~If=0*$6;L4!kVFEZDm8SF!I40Mw2+X16hWGmhg2&M5hX}b z7*Vk@;-CT+1Tg|C78KO6;V5OWOx{f>y55?%-n`$RzvWKOy{bz%yZ-1E~Wfe1_l; zfaU?T0t2uYnm>)pAj3%x908&z?p1|(RS;EyHDP>wYXm?|-rEA7fxD9n+nr#?_VQqR z+4wlI-6uWvGzyt4Cj`R-o^#^t#}k=QOvn@hppq%e8DtuRYEFQU6e@#81Hdq|cq$vj z(ainR_~bU?*X*|$OkgM<0Qif6W24AaDkUE|l_jDvA2p3T^CzkR^AH31ScPAXU+1e& z^HvEW}La1J|)oc@o3KLX~f@>NH=oaEl`T zGB^&tFkdk*icBYi!h&yb;a6NXr3)@%88jIa&Y}!4UH$Yj-j|GDysdm;JX zT|S{p-Z4Nj6-SN{Ut>y;lUVLYP3OEyf}Zrdz<#A~I8)l;pptL2OTN)AQT>KL^DBm@ zkefUmDm+x+f&Qyb_$!|=*_0t~seeDiF{U~ahKktA^0@Gj&jHzZ+c`PGfbcg=&X*5a zzv`FpxHut$M2ZocaJYgHo=K>HPfFwnNpSy?fTeY!kP{lkiz9^a);>HX#$?yG7hEhL^qL1e%FJw2nlf z)9EBKl|-c)!wh3_QcN5t(Ktq|KNZ2ACk_=w2;(9IF$8%;PKY2rj!A@PR;0v`wopp(o9@c2q#5^Wp_ z?h#Q@0ujvM^9gcY7WQ8WPWDWqg&B3ZxtS?kPAZ*jMm48W$P)x_o+yG7J(ZW4IhD5D zf-aXSWY~##oH&8VM<9r15~tsF-(2I_bK-bRqP)G0$uwiCxzE(g&jfxQOc31!+=#HG z|3aLu7eh`|%*%tD?>_}5!QDfeY#%wC?nA4;@lbrHe6PUw3Vg4?_X>Qk!1oILf2+XP zl`Jm?UO^r^0n0}?%^7HkBX2^m<}I=P|b%WMFw@lYT@&R?_=4{tUQ0D(tX10^zz zLZ#902rHm0UxScoRC7FH`L{nTr$`gA7Aa!QiEH@s8-0gC>RmMU_lp$|5VDH^s|<2L zIuarR1QtSKA=zc12QSIdkirzV!3TJ!hC(Z0lvPw`z=Q^Pj|U-<2ow^HmZu4&%6D=o zEL!~sij9(nCkI1})ud(=98}h~t^ZZa>(;x4v=C9Iit1c#oX)&Oiwz8omYABE!wWfv zoxOwO%2iIz-afv5{sF54L%F=Lb>R__;<)(r35iL`S(`R**}83ec45)ZUAy<}-B*05 zTvBnkvZ}h~_=%GZr%s2O1#ihFWWmP?W&frN-mgLk6bgyL$aO)81i5f53jG5`N!`X1!->@(QZtk_ zZ3_<8|Ei)-^LnQhBD$qIccE$jqTl7x6w3adu+0CHvT0#ob=?PcKto{!4Gn@Z{JsCQYcYEV&g;n+y&DY2$E<^cS)8^;!GEeZ%+t&n9~S?+le?_rOdf z1UC~B3#>t-rnBvj;|W?WT`bh>v+*z|N1_^x!}j1a&z?!!96RY)k@FkJWD54=bwBWT2Qpw z1XHM%QrvsLVL;gXkdiGKtv$K&Onk?4*4(2uhH9MRPJjFT*YP|DuZ&selDJ#wH&-64 zmgwvq8+QOfr=C$p%;p=>cCB@;4{th)d+lc5w4 z_<>+4)dep+QIGR1%8hg^HmbERD8Zkn*>*(FCL0X=)Z%e)=RB^kQT!z6^Ytrnub<+MwzE<)ZH7%jfI7b= z)orE4l_w$lE02#a+;B@Q1A3GKmX5amEYc@x8pZQRyt-Qa5&v>^S4_>Z8X>nM&m$a> z+dSuVaO7tjtwTck1D}#(>WeLAbxZv788cJLvRHbuTZ0#*f3%Dggk2R zO9-v7nN@Pa!q>Ya#REC3WvG#s?HPA{OhkDR>3D2UEH*Rt)fKbo3R3;CN}aOV$8QwG zWL3AYrQIbl1AbNq7P`dDe63`mJCZ~?=)Fl8j&#z=sCvGPZ8qa4{`@@Wns~!Cqq=Tj z)%eZGUvAAp8XZ||AG_;GO;GBM(L87ALXd?6KXh*Y@6He(bg@ z|KZoNSvsU*zp!AdV9cOLt5vhIes;M72n<=U-CzhWE#>&$o1bg#T;6UB(qvtSdV$fK z?BAq71KA)J1KdWfF12$9oI!Q<<7l(?YZZq-XnyIm7^%zi)+6>Xwkr`4g+yB$X26{jmWyq6x2VCr8bk2QVOii_GP<3*_#Nt) z7;%1b_sKnTcHexgWRyARuY+H1$HHEU(46s}o9Z62W6`J9gCa9fh zgxPm;2#ma>N&T5AHT9L}HP)C7*JwwU6bd6VC1P|KcwBlC>TtljZY2z0dpe7I7(rBZ z@1?oT2lqKX*WDEQp|)z9A-gM`&fVlh*j9pwEl$}zOGl4PDDl3%z4cSeNdvr*3`Bci zWVdfhd;hUXS!kY_#(24N!8WHiLGjnNOVl(FSf!V*_ayAn>-0b~S8aT=6n!hZ%?1bj zS62Vrw)NrJ>h1mn{7hAleP=Yk^6*pFMjq+(`tHCTjtSV_J9)+iMK|)f31upTEwgpk z&yIam;@f>bqbcqvuC%#4dnqe+0C^xZa0cFG1}cZZOuqtjRLg(!Ss0ekTUq43+e;}G zO}%GMRndFJ!nQ2n-1}nZ^DeC>DEYv1uf`v5cCwa2lPon-4;xEG|& z)K--xO)cqqwF#@N<8@QfNuzV{yA1NW&znk*Zdx_Z7l*WxLQx)i=e>u5DtZN$WJ{&n*J}7~(bmETFF#>OYpW|&eFvvRQ5{cL4^B+@Mb4F`t81+XL zO1`|sG-*N>7aZsvxGL@i()(J^U4HaHYg>n9hTQ1c7%@v*UpI`_rj@Fs6uHpX;mOZ| z&D<6D-{y6v8#JWH(sI+b-hWuL_M#t~o{Y^JsZH&D<=TJqb~}h&p7x4;?oQi#{?6R< zQ6XcFE(adk-q%~OV&2v8_Cw2-TkD77hpnAKv!AiW#=#?gKZSzeI3cJp3DB00}_GvBPY}uKUzUrfw zcD}IBzwF-qMRgqZNa*1tv#eU%br-fMZ#5xXjF>_@qWH0iPb@S~Y|QE_y(6`ev{q&; zYrhb^{za4Xz0;@u;od+AtXqZpbYerdDTz5?BUvG#_nm2Z6urSiN$Ya6%brn->Nn5y zgubhz!(8^~D9>s$_$A6Fte|duFhCkVfMm^m(;xBAQq6?geA_CY7fTo4_pYaY**_oG zA24if_oq1UayGCVa(8HVoE@1sPA^m0`83P5RM{ky;(ky%YG$af*%b?yEjc&CRA+Yw z)nHB2W$V`Q9#a>iUw;mGxK!J#i$7#jVC9t7h*>LdlJ1m}M_2zG^y}(a`sYYM(SBQ& z^;uujrpBfr4~%C{X6lRWtgsOJ>j;m}m9N;FxbqpDF(#vnxwvB%D?Yo`e0qMRgZzr4zVx)#vDLtANl7-4^>EvsjN0f7 z5MFsiH1cI2xpw#B_m0fax?*CqW2{MaUy>(F srNh~_&nK(Gc0(u1wWFP--9g-Jy1zP0UhljA0`bmLEC2ui literal 0 HcmV?d00001 diff --git a/tests/test_image_upscale.py b/tests/test_image_upscale.py new file mode 100644 index 0000000..1f93174 --- /dev/null +++ b/tests/test_image_upscale.py @@ -0,0 +1,3 @@ +import pytest +import video_upscaler +