Compare commits

...

3 Commits

Author SHA1 Message Date
Roman Krček
3075743c5d Fix linter issues and unittests
Some checks failed
Build Docker image / build (push) Has been skipped
Build Docker image / test (push) Failing after 2m41s
2024-10-13 21:18:18 +02:00
Roman Krček
6d508121b0 Add hash checking to TT downloading 2024-10-13 21:07:56 +02:00
Roman Krček
47248f10ab Add computed properties to settings 2024-10-13 21:05:51 +02:00
6 changed files with 93 additions and 29 deletions

View File

@@ -6,4 +6,5 @@ tiktok_downloader==0.3.5
uvloop==0.19.0 uvloop==0.19.0
tgcrypto==1.2.5 tgcrypto==1.2.5
sentry-sdk==2.15.0 sentry-sdk==2.15.0
pydantic-settings==2.5.2 pydantic-settings==2.5.2
pydantic==2.9.2

View File

@@ -56,7 +56,7 @@ async def message_handler(_, message: Message):
msg = f"Downloading video {i+1}/{len(urls)}..." msg = f"Downloading video {i+1}/{len(urls)}..."
log.info(msg) log.info(msg)
await message.reply_text(msg) await message.reply_text(msg)
utils.download_tt_video(settings.storage, url) utils.download_tt_video(url)
await message.reply_text("Done.") await message.reply_text("Done.")
@@ -67,7 +67,7 @@ async def media_handler(client, message: Message):
await message.reply_text("Downloading media...") await message.reply_text("Downloading media...")
utils.handle_media_message_contents(settings.storage, client, message) utils.handle_media_message_contents(client, message)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -3,19 +3,17 @@ from functools import wraps
from telegram_downloader_bot.logger import log from telegram_downloader_bot.logger import log
from telegram_downloader_bot.settings import settings from telegram_downloader_bot.settings import settings
allowed_ids = settings.allowed_ids.split(",")
allowed_ids = [int(x) for x in allowed_ids]
def protected(func): def protected(func):
@wraps(func) @wraps(func)
async def wrapper(client, message): async def wrapper(client, message):
if int(message.from_user.id) not in allowed_ids: if int(message.from_user.id) not in settings.allowed_ids_list:
log.warning( log.warning(
f"User with ID {message.from_user.id} attempted" f"User with ID {message.from_user.id} attempted"
"to text this bot!") "to text this bot!")
log.info( log.info(
f"Only users allowed are: {' '.join(allowed_ids)}") "Only users allowed are:"
f"{' '.join(settings.allowed_ids_list)}")
return await message.reply_text("You are not on the list!") return await message.reply_text("You are not on the list!")
return await func(client, message) return await func(client, message)
return wrapper return wrapper

View File

@@ -1,5 +1,7 @@
import os import os
from functools import cached_property
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from pydantic import computed_field
class Settings(BaseSettings): class Settings(BaseSettings):
@@ -46,10 +48,22 @@ class Settings(BaseSettings):
api_id: int api_id: int
api_hash: str api_hash: str
bot_token: str bot_token: str
storage: os.path storage: str
allowed_ids: str allowed_ids: str
log_level: str log_level: str
@computed_field
@property
def tt_hash_file(self) -> str:
return os.path.join(settings.storage, "tt_hashes.pickle")
@computed_field
@cached_property
def allowed_ids_list(self) -> list:
allowed_ids = settings.allowed_ids.split(",")
allowed_ids = [int(x) for x in allowed_ids]
return allowed_ids
class Config: class Config:
env_file = ".env" env_file = ".env"

View File

@@ -1,11 +1,15 @@
import os import os
import pickle # nosec
import re import re
from datetime import datetime from datetime import datetime
from hashlib import sha256
from pyrogram import Client from pyrogram import Client
from pyrogram.types import Message from pyrogram.types import Message
from tiktok_downloader import snaptik from tiktok_downloader import snaptik
from telegram_downloader_bot.settings import settings
def sanitize_name(input: str) -> str: def sanitize_name(input: str) -> str:
"""Sanize string by removing non aplhanumeric characters and spaces.""" """Sanize string by removing non aplhanumeric characters and spaces."""
@@ -14,7 +18,7 @@ def sanitize_name(input: str) -> str:
return output return output
def get_user_folder(storage_path: os.path, message: Message) -> os.path: def get_user_folder(message: Message) -> os.path:
""" Determine folder name used to save the media to. Depending on """ Determine folder name used to save the media to. Depending on
which type of message (forwarded, direct) detect that person's which type of message (forwarded, direct) detect that person's
or group's name.""" or group's name."""
@@ -47,19 +51,18 @@ def get_user_folder(storage_path: os.path, message: Message) -> os.path:
# Sanitize the folder name # Sanitize the folder name
user_folder_name = sanitize_name(user_folder_name) user_folder_name = sanitize_name(user_folder_name)
user_folder = os.path.join(storage_path, "telegram", user_folder_name) user_folder = os.path.join(settings.storage, "telegram", user_folder_name)
os.makedirs(user_folder, exist_ok=True) os.makedirs(user_folder, exist_ok=True)
return user_folder return user_folder
async def handle_media_message_contents(storage_path: os.path, async def handle_media_message_contents(client: Client,
client: Client,
message: Message): message: Message):
"""Detect what kind of media is being sent over from the user. """Detect what kind of media is being sent over from the user.
Based on that, determine the correct file extension and save Based on that, determine the correct file extension and save
that media.""" that media."""
user_folder = get_user_folder(storage_path, message) user_folder = get_user_folder(message)
# Handle documents # Handle documents
if message.document: if message.document:
@@ -94,7 +97,35 @@ async def handle_media_message_contents(storage_path: os.path,
await message.reply_text("Unknown media type!") await message.reply_text("Unknown media type!")
def download_tt_video(storage_path: str, url: str) -> None: def get_tt_hashes() -> set:
if not os.path.exists(settings.tt_hash_file):
return set()
with open(settings.tt_hash_file, "rb+") as f:
all_tt_hashes: set = pickle.load(f) # nosec
print(all_tt_hashes)
return all_tt_hashes
def add_to_hashes(new_hash: str) -> None:
all_tt_hashes = get_tt_hashes()
all_tt_hashes.add(new_hash)
save_tt_hashes(all_tt_hashes)
def save_tt_hashes(hashes: set) -> None:
with open(settings.tt_hash_file, "wb+") as f:
pickle.dump(hashes,
f,
protocol=pickle.HIGHEST_PROTOCOL)
def check_if_tt_downloaded(tt_hash: str) -> bool:
all_tt_hashes = get_tt_hashes()
return tt_hash in all_tt_hashes
def download_tt_video(url: str) -> str:
"""Downloads tiktok video from a given URL. """Downloads tiktok video from a given URL.
Makes sure the video integrity is correct.""" Makes sure the video integrity is correct."""
@@ -103,14 +134,24 @@ def download_tt_video(storage_path: str, url: str) -> None:
for video in videos: for video in videos:
video_filename = now.strftime("video-tiktok-%Y-%m-%d_%H-%M-%S.mp4") video_filename = now.strftime("video-tiktok-%Y-%m-%d_%H-%M-%S.mp4")
video_filepath: os.path = os.path.join(storage_path, video_filepath: os.path = os.path.join(settings.storage,
"tiktok", "tiktok",
video_filename) video_filename)
video_content = video.download().getbuffer() video_content = video.download().getbuffer()
video_hash = sha256(video_content).hexdigest()
print(video_hash)
if check_if_tt_downloaded(video_hash):
return "Already downloaded"
with open(video_filepath, "wb") as f: with open(video_filepath, "wb") as f:
f.write(video_content) f.write(video_content)
add_to_hashes(video_hash)
return "Downloaded ok"
def make_fs(storaga_path: str) -> None: def make_fs(storaga_path: str) -> None:
os.makedirs(os.path.join(storaga_path, "tiktok"), exist_ok=True) os.makedirs(os.path.join(storaga_path, "tiktok"), exist_ok=True)

View File

@@ -49,8 +49,11 @@ class TestGetUserFolder(unittest.TestCase):
def setUp(self): def setUp(self):
# Create a temporary directory for each test # Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
self.settings_patcher = patch('telegram_downloader_bot.settings.settings.storage', self.tmp_path)
self.settings_patcher.start()
def tearDown(self): def tearDown(self):
self.settings_patcher.stop()
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
@@ -65,7 +68,7 @@ class TestGetUserFolder(unittest.TestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = None message.from_user = None
result = get_user_folder(self.tmp_path, message) result = get_user_folder(message)
expected_folder = os.path.join(self.tmp_path, "telegram", "John_Doe") expected_folder = os.path.join(self.tmp_path, "telegram", "John_Doe")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
@@ -81,7 +84,7 @@ class TestGetUserFolder(unittest.TestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = None message.from_user = None
result = get_user_folder(self.tmp_path, message) result = get_user_folder(message)
expected_folder = os.path.join(self.tmp_path, "telegram", "12345") expected_folder = os.path.join(self.tmp_path, "telegram", "12345")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
@@ -95,7 +98,7 @@ class TestGetUserFolder(unittest.TestCase):
message.forward_from_chat = chat message.forward_from_chat = chat
message.from_user = None message.from_user = None
result = get_user_folder(self.tmp_path, message) result = get_user_folder(message)
expected_folder = os.path.join( expected_folder = os.path.join(
self.tmp_path, "telegram", "My_Awesome_GroupChat" self.tmp_path, "telegram", "My_Awesome_GroupChat"
) )
@@ -113,7 +116,7 @@ class TestGetUserFolder(unittest.TestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = user message.from_user = user
result = get_user_folder(self.tmp_path, message) result = get_user_folder(message)
expected_folder = os.path.join(self.tmp_path, "telegram", "Jane_Doe") expected_folder = os.path.join(self.tmp_path, "telegram", "Jane_Doe")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
@@ -129,7 +132,7 @@ class TestGetUserFolder(unittest.TestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = user message.from_user = user
result = get_user_folder(self.tmp_path, message) result = get_user_folder(message)
expected_folder = os.path.join(self.tmp_path, "telegram", "54321") expected_folder = os.path.join(self.tmp_path, "telegram", "54321")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
@@ -139,8 +142,12 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
# Create a temporary directory for each test # Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
self.settings_patcher = patch('telegram_downloader_bot.settings.settings.storage', self.tmp_path)
self.settings_patcher.start()
def tearDown(self): def tearDown(self):
# Stop patching settings.storage
self.settings_patcher.stop()
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
@@ -161,7 +168,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
message.animation = None message.animation = None
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(client, message)
expected_file_name = f"video_{message.video.file_id}.mp4" expected_file_name = f"video_{message.video.file_id}.mp4"
expected_file_path = os.path.join(user_folder, expected_file_name) expected_file_path = os.path.join(user_folder, expected_file_name)
@@ -187,7 +194,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
message.animation.file_id = "animation_file_id" message.animation.file_id = "animation_file_id"
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(client, message)
expected_file_name = f"gif_{message.animation.file_id}.gif" expected_file_name = f"gif_{message.animation.file_id}.gif"
expected_file_path = os.path.join(user_folder, expected_file_name) expected_file_path = os.path.join(user_folder, expected_file_name)
@@ -213,7 +220,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
message.animation = None message.animation = None
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(client, message)
expected_file_path = os.path.join(user_folder, "test_document.pdf") expected_file_path = os.path.join(user_folder, "test_document.pdf")
client.download_media.assert_awaited_once_with( client.download_media.assert_awaited_once_with(
@@ -238,7 +245,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
message.animation = None message.animation = None
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(client, message)
expected_file_name = f"photo_{message.photo.file_id}.jpg" expected_file_name = f"photo_{message.photo.file_id}.jpg"
expected_file_path = os.path.join(user_folder, expected_file_name) expected_file_path = os.path.join(user_folder, expected_file_name)
@@ -263,7 +270,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
message.animation = None message.animation = None
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(client, message)
client.download_media.assert_not_called() client.download_media.assert_not_called()
message.reply_text.assert_awaited_once_with("Unknown media type!") message.reply_text.assert_awaited_once_with("Unknown media type!")
@@ -274,6 +281,8 @@ class TestDownloadTTVideo(unittest.TestCase):
# Create a temporary directory for each test # Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
os.makedirs(os.path.join(self.tmp_path, "tiktok"), exist_ok=True) os.makedirs(os.path.join(self.tmp_path, "tiktok"), exist_ok=True)
self.settings_patcher = patch("telegram_downloader_bot.settings.settings.storage", self.tmp_path)
self.settings_patcher.start()
# Paths to the valid and invalid video files # Paths to the valid and invalid video files
self.valid_video_path = os.path.join(self.tmp_path, "valid.mp4") self.valid_video_path = os.path.join(self.tmp_path, "valid.mp4")
@@ -285,6 +294,7 @@ class TestDownloadTTVideo(unittest.TestCase):
f.write(b'invalid mp4 content') f.write(b'invalid mp4 content')
def tearDown(self): def tearDown(self):
self.settings_patcher.stop()
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
@@ -305,7 +315,7 @@ class TestDownloadTTVideo(unittest.TestCase):
mock_snaptik.return_value = [mock_video] mock_snaptik.return_value = [mock_video]
# Call the function # Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123") download_tt_video("http://tiktok.com/video123")
# Verify that the file was saved correctly # Verify that the file was saved correctly
video_filename = mock_now.strftime( video_filename = mock_now.strftime(
@@ -334,7 +344,7 @@ class TestDownloadTTVideo(unittest.TestCase):
mock_snaptik.return_value = [mock_video] mock_snaptik.return_value = [mock_video]
# Call the function # Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123") download_tt_video("http://tiktok.com/video123")
# Verify that the file was saved # Verify that the file was saved
video_filename = mock_now.strftime( video_filename = mock_now.strftime(
@@ -357,7 +367,7 @@ class TestDownloadTTVideo(unittest.TestCase):
mock_snaptik.return_value = [] mock_snaptik.return_value = []
# Call the function # Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123") download_tt_video("http://tiktok.com/video123")
# Verify that no files were created # Verify that no files were created
tiktok_folder = os.path.join(self.tmp_path, "tiktok") tiktok_folder = os.path.join(self.tmp_path, "tiktok")