Fix unit tests

This commit is contained in:
Roman Krček
2024-10-13 18:15:05 +02:00
parent af6282e26d
commit 2f3a2d1700

View File

@@ -1,13 +1,16 @@
# test_utils.py
import unittest import unittest
import os import os
import re import re
import asyncio
import tempfile
import shutil import shutil
import tempfile
from unittest.mock import Mock, AsyncMock, patch from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime from datetime import datetime
# Adjusted import statement
from telegram_downloader_bot.utils import ( from telegram_downloader_bot.utils import (
sanitize_name,
get_user_folder, get_user_folder,
handle_media_message_contents, handle_media_message_contents,
download_tt_video, download_tt_video,
@@ -17,8 +20,32 @@ from telegram_downloader_bot.utils import (
) )
from pyrogram.types import Message, User, Chat from pyrogram.types import Message, User, Chat
from pyrogram import Client
class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
class TestSanitizeName(unittest.TestCase):
def test_alphanumeric_input(self):
input_str = "JohnDoe123"
expected_output = "JohnDoe123"
self.assertEqual(sanitize_name(input_str), expected_output)
def test_input_with_special_chars(self):
input_str = "John Doe!@#"
expected_output = "John_Doe"
self.assertEqual(sanitize_name(input_str), expected_output)
def test_input_with_only_special_chars(self):
input_str = "!@#$%^&*()"
expected_output = ""
self.assertEqual(sanitize_name(input_str), expected_output)
def test_empty_input(self):
input_str = ""
expected_output = ""
self.assertEqual(sanitize_name(input_str), expected_output)
class TestGetUserFolder(unittest.TestCase):
def setUp(self): def setUp(self):
# Create a temporary directory for each test # Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
@@ -27,7 +54,7 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
async def test_forward_from_full_name(self): def test_forward_from_full_name(self):
user = Mock() user = Mock()
user.first_name = "John" user.first_name = "John"
user.last_name = "Doe" user.last_name = "Doe"
@@ -38,12 +65,12 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = None message.from_user = None
result = await get_user_folder(self.tmp_path, message) result = get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "John_Doe") expected_folder = os.path.join(self.tmp_path, "telegram", "John_Doe")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
async def test_forward_from_first_name_only(self): def test_forward_from_first_name_only(self):
user = Mock() user = Mock()
user.first_name = "John" user.first_name = "John"
user.last_name = None user.last_name = None
@@ -54,28 +81,12 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = None message.from_user = None
result = await get_user_folder(self.tmp_path, message) result = get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "12345") expected_folder = os.path.join(self.tmp_path, "telegram", "12345")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
async def test_forward_from_no_name(self): def test_forward_from_chat_title(self):
user = Mock()
user.first_name = None
user.last_name = None
user.id = 12345
message = Mock()
message.forward_from = user
message.forward_from_chat = None
message.from_user = None
result = await get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "12345")
self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder))
async def test_forward_from_chat_special_chars(self):
chat = Mock() chat = Mock()
chat.title = "My *Awesome* Group/Chat!" chat.title = "My *Awesome* Group/Chat!"
@@ -84,14 +95,14 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
message.forward_from_chat = chat message.forward_from_chat = chat
message.from_user = None message.from_user = None
result = await get_user_folder(self.tmp_path, message) result = get_user_folder(self.tmp_path, message)
expected_folder = os.path.join( expected_folder = os.path.join(
self.tmp_path, "telegram", "My_Awesome_GroupChat" self.tmp_path, "telegram", "My_Awesome_GroupChat"
) )
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
async def test_from_user_full_name(self): def test_from_user_full_name(self):
user = Mock() user = Mock()
user.first_name = "Jane" user.first_name = "Jane"
user.last_name = "Doe" user.last_name = "Doe"
@@ -102,14 +113,14 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = user message.from_user = user
result = await get_user_folder(self.tmp_path, message) result = get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "Jane_Doe") expected_folder = os.path.join(self.tmp_path, "telegram", "Jane_Doe")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
async def test_from_user_first_name_only(self): def test_from_user_id(self):
user = Mock() user = Mock()
user.first_name = "Jane" user.first_name = None
user.last_name = None user.last_name = None
user.id = 54321 user.id = 54321
@@ -118,27 +129,11 @@ class TestGetUserFolder(unittest.IsolatedAsyncioTestCase):
message.forward_from_chat = None message.forward_from_chat = None
message.from_user = user message.from_user = user
result = await get_user_folder(self.tmp_path, message) result = get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "54321") expected_folder = os.path.join(self.tmp_path, "telegram", "54321")
self.assertEqual(result, expected_folder) self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder)) self.assertTrue(os.path.exists(expected_folder))
async def test_special_characters_in_name(self):
user = Mock()
user.first_name = "Ja*ne"
user.last_name = "Do/e"
user.id = 54321
message = Mock()
message.forward_from = None
message.forward_from_chat = None
message.from_user = user
result = await get_user_folder(self.tmp_path, message)
expected_folder = os.path.join(self.tmp_path, "telegram", "Jane_Doe")
self.assertEqual(result, expected_folder)
self.assertTrue(os.path.exists(expected_folder))
class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase): class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
@@ -149,16 +144,68 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
@patch("telegram_downloader_bot.utils.get_user_folder") @patch('telegram_downloader_bot.utils.get_user_folder')
async def test_document(self, mock_get_user_folder): async def test_handle_video(self, mock_get_user_folder):
user_folder = os.path.join(self.tmp_path, "user_folder") user_folder = os.path.join(self.tmp_path, "user_folder")
mock_get_user_folder.return_value = user_folder mock_get_user_folder.return_value = user_folder
os.makedirs(user_folder, exist_ok=True) os.makedirs(user_folder, exist_ok=True)
client = Mock() client = Mock(spec=Client)
client.download_media = AsyncMock() client.download_media = AsyncMock()
message = Mock() message = Mock(spec=Message)
message.document = None
message.photo = None
message.video = Mock()
message.video.file_id = "video_file_id"
message.animation = None
message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message)
expected_file_name = f"video_{message.video.file_id}.mp4"
expected_file_path = os.path.join(user_folder, expected_file_name)
client.download_media.assert_awaited_once_with(
message, expected_file_path)
message.reply_text.assert_awaited_once_with(
f"Video saved to {user_folder}")
@patch('telegram_downloader_bot.utils.get_user_folder')
async def test_handle_animation(self, mock_get_user_folder):
user_folder = os.path.join(self.tmp_path, "user_folder")
mock_get_user_folder.return_value = user_folder
os.makedirs(user_folder, exist_ok=True)
client = Mock(spec=Client)
client.download_media = AsyncMock()
message = Mock(spec=Message)
message.document = None
message.photo = None
message.video = None
message.animation = Mock()
message.animation.file_id = "animation_file_id"
message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message)
expected_file_name = f"gif_{message.animation.file_id}.gif"
expected_file_path = os.path.join(user_folder, expected_file_name)
client.download_media.assert_awaited_once_with(
message.animation, expected_file_path)
message.reply_text.assert_awaited_once_with(
f"GIF saved to {user_folder}")
@patch('telegram_downloader_bot.utils.get_user_folder')
async def test_handle_document(self, mock_get_user_folder):
user_folder = os.path.join(self.tmp_path, "user_folder")
mock_get_user_folder.return_value = user_folder
os.makedirs(user_folder, exist_ok=True)
client = Mock(spec=Client)
client.download_media = AsyncMock()
message = Mock(spec=Message)
message.document = Mock() message.document = Mock()
message.document.file_name = "test_document.pdf" message.document.file_name = "test_document.pdf"
message.photo = None message.photo = None
@@ -168,44 +215,48 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(self.tmp_path, client, message)
expected_file_path = os.path.join(user_folder, "test_document.pdf")
client.download_media.assert_awaited_once_with( client.download_media.assert_awaited_once_with(
message, os.path.join(user_folder, "test_document.pdf") message, expected_file_path)
) message.reply_text.assert_awaited_once_with(
message.reply_text.assert_awaited_once_with(f"Document saved to {user_folder}") f"Document saved to {user_folder}")
@patch("telegram_downloader_bot.utils.get_user_folder") @patch('telegram_downloader_bot.utils.get_user_folder')
async def test_photo(self, mock_get_user_folder): async def test_handle_photo(self, mock_get_user_folder):
user_folder = os.path.join(self.tmp_path, "user_folder") user_folder = os.path.join(self.tmp_path, "user_folder")
mock_get_user_folder.return_value = user_folder mock_get_user_folder.return_value = user_folder
os.makedirs(user_folder, exist_ok=True) os.makedirs(user_folder, exist_ok=True)
client = Mock() client = Mock(spec=Client)
client.download_media = AsyncMock() client.download_media = AsyncMock()
message = Mock() message = Mock(spec=Message)
message.document = None message.document = None
message.photo = Mock() message.photo = Mock()
message.photo.file_id = "1234567890" message.photo.file_id = "photo_file_id"
message.video = None message.video = None
message.animation = None message.animation = None
message.reply_text = AsyncMock() message.reply_text = AsyncMock()
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(self.tmp_path, client, message)
expected_file = os.path.join(user_folder, f"photo_{message.photo.file_id}.jpg") expected_file_name = f"photo_{message.photo.file_id}.jpg"
client.download_media.assert_awaited_once_with(message.photo, expected_file) expected_file_path = os.path.join(user_folder, expected_file_name)
message.reply_text.assert_awaited_once_with(f"Photo saved to {user_folder}") client.download_media.assert_awaited_once_with(
message.photo, expected_file_path)
message.reply_text.assert_awaited_once_with(
f"Photo saved to {user_folder}")
@patch("telegram_downloader_bot.utils.get_user_folder") @patch('telegram_downloader_bot.utils.get_user_folder')
async def test_unknown_media(self, mock_get_user_folder): async def test_handle_unknown_media(self, mock_get_user_folder):
user_folder = os.path.join(self.tmp_path, "user_folder") user_folder = os.path.join(self.tmp_path, "user_folder")
mock_get_user_folder.return_value = user_folder mock_get_user_folder.return_value = user_folder
os.makedirs(user_folder, exist_ok=True) os.makedirs(user_folder, exist_ok=True)
client = Mock() client = Mock(spec=Client)
client.download_media = AsyncMock() client.download_media = AsyncMock()
message = Mock() message = Mock(spec=Message)
message.document = None message.document = None
message.photo = None message.photo = None
message.video = None message.video = None
@@ -214,7 +265,7 @@ class TestHandleMediaMessageContents(unittest.IsolatedAsyncioTestCase):
await handle_media_message_contents(self.tmp_path, client, message) await handle_media_message_contents(self.tmp_path, client, message)
client.download_media.assert_not_awaited() client.download_media.assert_not_called()
message.reply_text.assert_awaited_once_with("Unknown media type!") message.reply_text.assert_awaited_once_with("Unknown media type!")
@@ -222,92 +273,145 @@ class TestDownloadTTVideo(unittest.TestCase):
def setUp(self): def setUp(self):
# Create a temporary directory for each test # Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
os.makedirs(os.path.join(self.tmp_path, "tiktok"), exist_ok=True)
# Paths to the valid and invalid video files
self.valid_video_path = os.path.join(self.tmp_path, "valid.mp4")
with open(self.valid_video_path, 'wb') as f:
f.write(b'valid mp4 content')
self.invalid_video_path = os.path.join(self.tmp_path, "invalid.mp4")
with open(self.invalid_video_path, 'wb') as f:
f.write(b'invalid mp4 content')
def tearDown(self): def tearDown(self):
# Remove the directory after the test # Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
@patch("telegram_downloader_bot.utils.snaptik") @patch('telegram_downloader_bot.utils.snaptik')
@patch("telegram_downloader_bot.utils.integv.verify") @patch('telegram_downloader_bot.utils.datetime')
@patch("telegram_downloader_bot.utils.datetime") def test_download_tt_video_with_valid_video(self, mock_datetime, mock_snaptik):
def test_success(self, mock_datetime, mock_verify, mock_snaptik): # Mock datetime
mock_video = Mock()
mock_video.download.return_value.getbuffer.return_value = b"video_content"
mock_snaptik.return_value = [mock_video]
mock_verify.return_value = True
mock_now = datetime(2023, 1, 1, 12, 0, 0) mock_now = datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.now.return_value = mock_now mock_datetime.datetime.now.return_value = mock_now
result = download_tt_video(self.tmp_path, "http://tiktok.com/video123") # Read the content of valid.mp4
self.assertTrue(result) with open(self.valid_video_path, 'rb') as f:
valid_video_content = f.read()
video_filename = mock_now.strftime("video-tiktok-%Y-%m-%d_%H-%M-%S.mp4") # Mock snaptik to return a video that returns valid.mp4 content
mock_video = Mock()
mock_video.download.return_value.getbuffer.return_value = valid_video_content
mock_snaptik.return_value = [mock_video]
# Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123")
# Verify that the file was saved correctly
video_filename = mock_now.strftime(
"video-tiktok-%Y-%m-%d_%H-%M-%S.mp4")
video_filepath = os.path.join(self.tmp_path, "tiktok", video_filename) video_filepath = os.path.join(self.tmp_path, "tiktok", video_filename)
self.assertTrue(os.path.exists(video_filepath)) self.assertTrue(os.path.exists(video_filepath))
@patch("telegram_downloader_bot.utils.snaptik") with open(video_filepath, 'rb') as f:
@patch("telegram_downloader_bot.utils.integv.verify") content = f.read()
@patch("telegram_downloader_bot.utils.datetime") self.assertEqual(content, valid_video_content)
def test_failure(self, mock_datetime, mock_verify, mock_snaptik):
mock_video = Mock()
mock_video.download.return_value.getbuffer.return_value = b"video_content"
mock_snaptik.return_value = [mock_video]
mock_verify.return_value = False
@patch('telegram_downloader_bot.utils.snaptik')
@patch('telegram_downloader_bot.utils.datetime')
def test_download_tt_video_with_invalid_video(self, mock_datetime, mock_snaptik):
# Mock datetime
mock_now = datetime(2023, 1, 1, 12, 0, 0) mock_now = datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.now.return_value = mock_now mock_datetime.datetime.now.return_value = mock_now
result = download_tt_video(self.tmp_path, "http://tiktok.com/video123") # Read the content of invalid.mp4
self.assertFalse(result) with open(self.invalid_video_path, 'rb') as f:
invalid_video_content = f.read()
video_filename = mock_now.strftime("video-tiktok-%Y-%m-%d_%H-%M-%S.mp4") # Mock snaptik to return a video that returns invalid.mp4 content
mock_video = Mock()
mock_video.download.return_value.getbuffer.return_value = invalid_video_content
mock_snaptik.return_value = [mock_video]
# Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123")
# Verify that the file was saved
video_filename = mock_now.strftime(
"video-tiktok-%Y-%m-%d_%H-%M-%S.mp4")
video_filepath = os.path.join(self.tmp_path, "tiktok", video_filename) video_filepath = os.path.join(self.tmp_path, "tiktok", video_filename)
self.assertFalse(os.path.exists(video_filepath)) self.assertTrue(os.path.exists(video_filepath))
with open(video_filepath, 'rb') as f:
content = f.read()
self.assertEqual(content, invalid_video_content)
@patch('telegram_downloader_bot.utils.snaptik')
@patch('telegram_downloader_bot.utils.datetime')
def test_download_tt_video_no_videos(self, mock_datetime, mock_snaptik):
# Mock datetime
mock_now = datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = mock_now
# Mock snaptik to return an empty list
mock_snaptik.return_value = []
# Call the function
download_tt_video(self.tmp_path, "http://tiktok.com/video123")
# Verify that no files were created
tiktok_folder = os.path.join(self.tmp_path, "tiktok")
files = os.listdir(tiktok_folder)
self.assertEqual(len(files), 0)
class TestMakeFS(unittest.TestCase): class TestMakeFS(unittest.TestCase):
def setUp(self): def setUp(self):
# Create a temporary directory for each test
self.tmp_path = tempfile.mkdtemp() self.tmp_path = tempfile.mkdtemp()
def tearDown(self): def tearDown(self):
# Remove the directory after the test
shutil.rmtree(self.tmp_path) shutil.rmtree(self.tmp_path)
def test_make_fs(self): def test_make_fs(self):
make_fs(self.tmp_path) make_fs(self.tmp_path)
self.assertTrue(os.path.exists(os.path.join(self.tmp_path, "tiktok"))) self.assertTrue(os.path.exists(os.path.join(self.tmp_path, "tiktok")))
self.assertTrue(os.path.exists(os.path.join(self.tmp_path, "telegram"))) self.assertTrue(os.path.exists(
os.path.join(self.tmp_path, "telegram")))
class TestExtractURLs(unittest.TestCase): class TestExtractURLs(unittest.TestCase):
def test_no_urls(self): def test_no_urls(self):
result = extract_urls("This is some text without any URLs.") text = "This is some text without any URLs."
result = extract_urls(text)
self.assertEqual(result, []) self.assertEqual(result, [])
def test_one_url(self): def test_single_url(self):
result = extract_urls("Check out this link: http://example.com") text = "Check out this link: http://example.com"
result = extract_urls(text)
self.assertEqual(result, ["http://example.com"]) self.assertEqual(result, ["http://example.com"])
def test_multiple_urls(self): def test_multiple_urls(self):
result = extract_urls( text = "Here are some links: http://example.com and https://test.com/page"
"Here are some links: http://example.com and https://test.com/page" result = extract_urls(text)
) self.assertEqual(
self.assertEqual(result, ["http://example.com", "https://test.com/page"]) result, ["http://example.com", "https://test.com/page"])
def test_malformed_url(self): def test_malformed_url(self):
result = extract_urls("This is not a URL: htt://badurl.com") text = "This is not a URL: htt://badurl.com"
result = extract_urls(text)
self.assertEqual(result, []) self.assertEqual(result, [])
def test_url_at_text_boundaries(self): def test_urls_with_special_chars(self):
result = extract_urls("http://start.com text in the middle https://end.com") text = "Link: https://example.com/page?param=value#anchor"
self.assertEqual(result, ["http://start.com", "https://end.com"]) result = extract_urls(text)
self.assertEqual(
result, ["https://example.com/page?param=value#anchor"])
class TestFilterTTURLs(unittest.TestCase): class TestFilterTTURLs(unittest.TestCase):
def test_empty_list(self): def test_empty_list(self):
result = filter_tt_urls([]) urls = []
result = filter_tt_urls(urls)
self.assertEqual(result, []) self.assertEqual(result, [])
def test_no_tiktok_urls(self): def test_no_tiktok_urls(self):
@@ -315,17 +419,22 @@ class TestFilterTTURLs(unittest.TestCase):
result = filter_tt_urls(urls) result = filter_tt_urls(urls)
self.assertEqual(result, []) self.assertEqual(result, [])
def test_only_tiktok_urls(self):
urls = ["http://tiktok.com/video1", "https://www.tiktok.com/@user/video/123"]
result = filter_tt_urls(urls)
self.assertEqual(result, urls)
def test_mixed_urls(self): def test_mixed_urls(self):
urls = ["http://example.com", "https://www.tiktok.com/@user/video/123"] urls = [
"http://example.com",
"https://www.tiktok.com/@user/video/123",
"http://tiktok.com/video1",
"https://test.com/page",
]
expected = [
"https://www.tiktok.com/@user/video/123",
"http://tiktok.com/video1",
]
result = filter_tt_urls(urls) result = filter_tt_urls(urls)
self.assertEqual(result, ["https://www.tiktok.com/@user/video/123"]) self.assertEqual(result, expected)
def test_tiktok_in_query(self): def test_tiktok_in_query_params(self):
urls = ["http://example.com?param=tiktok", "https://www.other.com/path"] urls = ["http://example.com?watch=tiktok", "https://other.com/path"]
expected = ["http://example.com?watch=tiktok"]
result = filter_tt_urls(urls) result = filter_tt_urls(urls)
self.assertEqual(result, ["http://example.com?param=tiktok"]) self.assertEqual(result, expected)