vtt_utils.py

"""
VTT (WebVTT) file parsing and manipulation utilities.
"""
import re
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Subtitle:
    """Represents a single subtitle entry."""
    start_time: str
    end_time: str
    text: str

    @property
    def start_seconds(self) -> float:
        """Convert start time to seconds."""
        return self._time_to_seconds(self.start_time)

    @property
    def end_seconds(self) -> float:
        """Convert end time to seconds."""
        return self._time_to_seconds(self.end_time)

    @staticmethod
    def _time_to_seconds(time_str: str) -> float:
        """Convert an HH:MM:SS.mmm (or MM:SS.mmm) timestamp to seconds."""
        parts = time_str.replace(',', '.').split(':')
        # WebVTT allows the hours component to be omitted (MM:SS.mmm)
        if len(parts) == 2:
            parts.insert(0, '0')
        hours = int(parts[0])
        minutes = int(parts[1])
        seconds = float(parts[2])
        return hours * 3600 + minutes * 60 + seconds

    def __str__(self) -> str:
        """Format as VTT subtitle entry."""
        return f"{self.start_time} --> {self.end_time}\n{self.text}"
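
# A quick illustration of Subtitle (the timestamps and text below are made up
# for the example):
#
#     sub = Subtitle(start_time='00:01:05.000', end_time='00:01:07.500',
#                    text='こんにちは')
#     sub.start_seconds  -> 65.0
#     sub.end_seconds    -> 67.5
#     str(sub)           -> '00:01:05.000 --> 00:01:07.500\nこんにちは'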


class VTTFile:
    """Handles VTT file parsing and manipulation."""

    def __init__(self, filepath: str):
        """Initialize VTT file handler."""
        self.filepath = filepath
        self.subtitles: List[Subtitle] = []
        self._parse()

    def _parse(self) -> None:
        """Parse VTT file into subtitle objects."""
        with open(self.filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        # Normalize line endings, then split by blank lines to get cue blocks
        content = content.replace('\r\n', '\n')
        blocks = content.strip().split('\n\n')
        for block in blocks:
            # Skip the WEBVTT header and empty blocks
            if block.startswith('WEBVTT') or not block.strip():
                continue
            lines = block.strip().split('\n')
            if len(lines) < 2:
                continue
            # Find the timestamp line; cues may start with an optional identifier
            timestamp_idx = next(
                (i for i, line in enumerate(lines) if '-->' in line), None
            )
            if timestamp_idx is None:
                continue
            # Parse timestamps
            parts = lines[timestamp_idx].split('-->')
            if len(parts) != 2:
                continue
            start_time = parts[0].strip()
            # Drop any cue settings that follow the end timestamp
            end_time = parts[1].strip().split(' ')[0]
            # Join the remaining lines as the cue text
            text = '\n'.join(lines[timestamp_idx + 1:])
            self.subtitles.append(Subtitle(
                start_time=start_time,
                end_time=end_time,
                text=text
            ))

    def get_duration(self) -> Tuple[int, int]:
        """
        Get the total duration of the VTT file.

        Returns a tuple of (total_minutes, total_hours), both rounded down
        to whole numbers.
        """
        if not self.subtitles:
            return 0, 0
        last_subtitle = self.subtitles[-1]
        total_seconds = last_subtitle.end_seconds
        total_minutes = int(total_seconds / 60)
        total_hours = total_minutes // 60
        return total_minutes, total_hours

    def to_string(self) -> str:
        """Convert VTT file back to string format."""
        lines = ['WEBVTT', '']
        for subtitle in self.subtitles:
            lines.append(str(subtitle))
            lines.append('')
        return '\n'.join(lines)

    def save(self, filepath: str) -> None:
        """Save VTT file to disk."""
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_string())

    def get_subtitle_range(self, start_idx: int, end_idx: int) -> 'VTTFile':
        """
        Create a new VTTFile with subtitles in the specified range.

        Returns a new VTTFile object containing subtitles[start_idx:end_idx].
        """
        # Bypass __init__ so the copy is not re-parsed from disk
        new_file = VTTFile.__new__(VTTFile)
        new_file.filepath = self.filepath
        new_file.subtitles = self.subtitles[start_idx:end_idx]
        return new_file


def estimate_token_count(text: str, avg_tokens_per_word: float = 1.3) -> int:
    """
    Rough estimate of token count using word count.

    Japanese typically has 1.3 tokens per word with most tokenizers.
    """
    words = len(text.split())
    return int(words * avg_tokens_per_word)
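
# Worked example (illustrative): '字幕のテスト です' splits into two whitespace-
# delimited "words", so estimate_token_count returns int(2 * 1.3) = 2. Because
# unsegmented Japanese contains no spaces, this estimate skews low for such text.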


def has_japanese_characters(text: str) -> bool:
    """Check if text contains Japanese characters (Hiragana, Katakana, Kanji)."""
    # Japanese Unicode ranges:
    #   Hiragana: U+3040-U+309F
    #   Katakana: U+30A0-U+30FF
    #   Kanji (CJK Unified Ideographs): U+4E00-U+9FFF
    japanese_pattern = r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]'
    return bool(re.search(japanese_pattern, text))
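

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module API above: it writes a tiny,
    # made-up VTT file to a temporary directory and runs it through the helpers
    # defined in this module. The cue timestamps and text are illustrative
    # assumptions only.
    import os
    import tempfile

    sample = (
        'WEBVTT\n'
        '\n'
        '00:00:01.000 --> 00:00:03.000\n'
        'こんにちは、世界\n'
        '\n'
        '01:04:58.500 --> 01:05:00.000\n'
        'This is a test subtitle.\n'
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'sample.vtt')
        with open(path, 'w', encoding='utf-8') as f:
            f.write(sample)

        vtt = VTTFile(path)
        minutes, hours = vtt.get_duration()
        print(f'Parsed {len(vtt.subtitles)} cues, duration {minutes} min ({hours} h)')

        # Slice out the first cue and re-serialize it
        first_chunk = vtt.get_subtitle_range(0, 1)
        print(first_chunk.to_string())

        for sub in vtt.subtitles:
            print(
                f'{sub.start_seconds:.1f}s-{sub.end_seconds:.1f}s | '
                f'~{estimate_token_count(sub.text)} tokens | '
                f'Japanese: {has_japanese_characters(sub.text)}'
            )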