"""Single-threaded PySide6 GUI that upscales a video 4x with Real-ESRGAN.

The model is executed through ``qai_appbuilder`` (QNN, HTP runtime).  Frames
are read with OpenCV, letterboxed to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, run
through the model, un-letterboxed, and written as PNG files into a temporary
directory.  The processed frames can then be combined into an MP4 file.
"""
import os
import shutil
import sys
import tempfile
import threading
import traceback

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from PIL.Image import fromarray as ImageFromArray

# Imports from qai_appbuilder
from qai_appbuilder import (QNNContext, Runtime, LogLevel, ProfilingLevel,
                            PerfProfile, QNNConfig)

# Imports from PySide6
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                               QHBoxLayout, QPushButton, QFileDialog, QLabel,
                               QProgressBar, QSlider, QMessageBox, QSizePolicy,
                               QDialog, QRadioButton, QGroupBox)
from PySide6.QtCore import (Qt, Signal, Slot, QPoint, QPointF, QTimer)
from PySide6.QtGui import (QImage, QPixmap, QCursor, QPainter, QFont)

# --- START OF MERGED real_esrgan_x4plus.py CONTENT ---
sys.path.append(".")
sys.path.append("..")

####################################################################

MODEL_ID = "mnz1l2exq"
MODEL_NAME = "real_esrgan_x4plus"
MODEL_HELP_URL = ("https://github.com/quic/ai-engine-direct-helper/tree/main/samples/python/"
                  + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models")
IMAGE_SIZE = 512
# The model output is reshaped to IMAGE_SIZE * 4 per side (see Inference).
UPSCALE_FACTOR = 4

####################################################################

execution_ws = os.getcwd()

# Support PyInstaller-packaged paths: resolve resources relative to the
# executable, preferring an "internal" sub-directory when it exists.
if getattr(sys, 'frozen', False):
    base_dir = os.path.dirname(sys.executable)
    execution_ws = base_dir
    internal_dir = os.path.join(base_dir, "internal")
    if os.path.exists(internal_dir):
        execution_ws = internal_dir

# Look for the model directory in three candidate locations, in order.
model_dir = os.path.join(execution_ws, "models")
if not os.path.exists(model_dir):
    model_dir = os.path.join(os.path.dirname(execution_ws), "models")
if not os.path.exists(model_dir):
    model_dir = os.path.join(execution_ws, "internal", "models")
if not os.path.exists(model_dir):
    raise FileNotFoundError(f"Model directory not found: {model_dir}")

# Look for the QNN runtime libraries in two candidate locations, in order.
qnn_dir = os.path.join(execution_ws, "qai_libs_2.32")
if not os.path.exists(qnn_dir):
    qnn_dir = os.path.join(execution_ws, "internal", "qai_libs_2.32")
if not os.path.exists(qnn_dir):
    raise FileNotFoundError(f"QNN library directory not found: {qnn_dir}")

CURRENT_MODEL_NAME = "real_esrgan_general_x4v3"
madel_path = os.path.join(model_dir, f"{CURRENT_MODEL_NAME}.bin")

####################################################################

image_buffer = None
realesrgan = None
model_initialized = False


def initialize_global_model():
    """Load the model named by ``CURRENT_MODEL_NAME`` into the global context.

    Any model that is already loaded is released first.

    Raises:
        FileNotFoundError: if the ``.bin`` file for the current model is missing.
        Exception: whatever ``QNNConfig.Config`` / ``QNNContext`` raise; it is
            logged and re-raised, and ``model_initialized`` is left ``False``.
    """
    global model_initialized, realesrgan, madel_path

    if model_initialized:
        release_global_model()

    madel_path = os.path.join(model_dir, f"{CURRENT_MODEL_NAME}.bin")
    if not os.path.exists(madel_path):
        raise FileNotFoundError(f"Model file not found: {madel_path}")

    try:
        QNNConfig.Config(qnn_dir, Runtime.HTP, LogLevel.WARN, ProfilingLevel.BASIC)
        realesrgan = QNNContext("realesrgan", madel_path)
        model_initialized = True
    except Exception as e:
        print(f"Global model initialization failed: {e}")
        model_initialized = False
        raise


def release_global_model():
    """Drop the global model reference and mark the model as not initialized."""
    global model_initialized, realesrgan
    if model_initialized:
        realesrgan = None
        model_initialized = False


def _fitted_size(height, width, scale):
    """Return ``(new_h, new_w)`` for an image scaled by *scale*, at least 1x1."""
    return max(1, int(height * scale)), max(1, int(width * scale))


def resize_and_pad(image: np.ndarray, target_size: tuple) -> tuple:
    """Letterbox *image* into ``target_size`` while keeping its aspect ratio.

    Args:
        image: array whose first two axes are (height, width).
        target_size: ``(target_h, target_w)``.

    Returns:
        ``(padded, scale, (pad_top, pad_left))``.  The image is centred; when
        the padding is odd the extra pixel goes to the bottom / right side.
    """
    h, w = image.shape[:2]
    target_h, target_w = target_size
    scale = min(target_h / h, target_w / w)
    # Clamp to one pixel so extreme aspect ratios never yield a 0-sized resize.
    new_h, new_w = _fitted_size(h, w, scale)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    pad_top = (target_h - new_h) // 2
    pad_bottom = target_h - new_h - pad_top
    pad_left = (target_w - new_w) // 2
    pad_right = target_w - new_w - pad_left
    padded = cv2.copyMakeBorder(resized, pad_top, pad_bottom, pad_left, pad_right,
                                cv2.BORDER_CONSTANT, value=0)
    return padded, scale, (pad_top, pad_left)


def undo_resize_and_pad(image: np.ndarray, original_size: tuple, scale: float,
                        padding: tuple, content_size=None) -> np.ndarray:
    """Crop the letterbox padding from *image* and resize to ``original_size``.

    Args:
        image: padded image.
        original_size: ``(width, height)`` of the result.
        scale: scale factor returned by :func:`resize_and_pad` (kept for
            interface compatibility; not used by the crop itself).
        padding: ``(pad_top, pad_left)`` expressed in *image* pixels.
        content_size: optional ``(width, height)`` of the un-padded region in
            *image*.  When omitted the padding is assumed to be symmetric,
            which is off by the odd padding pixel whenever the bottom / right
            padding is larger than the top / left padding.

    Returns:
        The cropped image resized to ``original_size``.
    """
    pad_top, pad_left = padding
    h, w = image.shape[:2]
    if content_size is None:
        bottom = h - pad_top
        right = w - pad_left
    else:
        content_w, content_h = content_size
        bottom = pad_top + content_h
        right = pad_left + content_w
    cropped = image[pad_top:bottom, pad_left:right]

    orig_w, orig_h = original_size
    return cv2.resize(cropped, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)


def Inference(input_image: np.ndarray) -> np.ndarray:
    """Upscale one frame by ``UPSCALE_FACTOR`` with the global model.

    Args:
        input_image: image array with (height, width) as its first two axes;
            the callers in this file pass RGB frames.

    Returns:
        ``uint8`` array of shape ``(height * 4, width * 4, 3)``.

    Raises:
        RuntimeError: if the global model has not been initialized.
    """
    global realesrgan
    if not model_initialized or realesrgan is None:
        raise RuntimeError("RealESRGan model not initialized")

    image, scale, padding = resize_and_pad(input_image, (IMAGE_SIZE, IMAGE_SIZE))
    image = np.clip(image, 0, 255) / 255.0
    image = image.astype(np.float32)

    PerfProfile.SetPerfProfileGlobal(PerfProfile.BURST)
    try:
        output_data = realesrgan.Inference([image])[0]
    finally:
        # Always release the burst profile, even when inference raises.
        PerfProfile.RelPerfProfileGlobal()

    out_side = IMAGE_SIZE * UPSCALE_FACTOR
    output_image = output_data.reshape(out_side, out_side, 3)
    output_image = np.clip(output_image, 0.0, 1.0)
    output_image = (output_image * 255).astype(np.uint8)

    in_h, in_w = input_image.shape[:2]
    content_h, content_w = _fitted_size(in_h, in_w, scale)
    orig_size = (in_w * UPSCALE_FACTOR, in_h * UPSCALE_FACTOR)
    padding_scaled = (padding[0] * UPSCALE_FACTOR, padding[1] * UPSCALE_FACTOR)
    # Pass the exact content size so asymmetric padding is cropped correctly.
    return undo_resize_and_pad(
        output_image, orig_size, scale, padding_scaled,
        content_size=(content_w * UPSCALE_FACTOR, content_h * UPSCALE_FACTOR))

# --- END OF MERGED real_esrgan_x4plus.py CONTENT ---


# --- Custom Draggable Label ---
class DraggableLabel(QLabel):
    """Label that paints a pixmap at an offset and supports drag-to-pan.

    Only the label that is the main window's ``enhanced_label`` reacts to the
    mouse; while dragging it mirrors its offset onto ``original_label`` so
    both views stay aligned.
    """

    def __init__(self, parent=None, main_window=None):
        super().__init__(parent)
        self.main_window = main_window
        self.setMouseTracking(True)
        self.dragging = False
        self.offset = QPointF(0.0, 0.0)
        self.last_pos = None
        self.original_pixmap = None
        self.setStyleSheet("QLabel { background-color: #000000; border-radius: 8px; color: #FFFFFF; font-size: 16px; padding: 16px; }")

    def _is_enhanced_label(self):
        """Return True when this label is the main window's enhanced view."""
        return self is getattr(self.main_window, "enhanced_label", None)

    def _processing_active(self):
        """Return True while the main window is processing and not paused."""
        window = self.main_window
        return bool(getattr(window, "is_processing", False)
                    and not getattr(window, "is_paused", False))

    @staticmethod
    def _clamp_pan(value, content_extent, view_extent):
        """Clamp a pan offset along one axis.

        Content larger than the view may be panned between its two edges;
        content that fits inside the view stays centred instead of snapping
        to the far edge.
        """
        slack = view_extent - content_extent
        if slack >= 0:
            return slack / 2
        return max(slack, min(0, value))

    def _set_idle_cursor(self):
        """Show the open hand when panning is allowed, otherwise the arrow."""
        shape = Qt.ArrowCursor if self._processing_active() else Qt.OpenHandCursor
        self.setCursor(QCursor(shape))

    def set_pixmap(self, pixmap, new_offset=None):
        """Store *pixmap* as the source image and repaint.

        The offset is reset to the origin unless *new_offset* is given.
        """
        self.original_pixmap = pixmap
        self.offset = QPointF(0.0, 0.0) if new_offset is None else new_offset
        self.update_display()

    def update_display(self):
        """Repaint the stored pixmap at the current offset onto the label."""
        if not self.original_pixmap or self.original_pixmap.isNull():
            super().setPixmap(QPixmap())
            return
        canvas = QPixmap(self.size())
        canvas.fill(Qt.transparent)
        painter = QPainter(canvas)
        painter.drawPixmap(self.offset.toPoint(), self.original_pixmap)
        painter.end()
        super().setPixmap(canvas)

    def enterEvent(self, event):
        if self._is_enhanced_label():
            self._set_idle_cursor()

    def leaveEvent(self, event):
        if self._is_enhanced_label():
            self.setCursor(QCursor(Qt.ArrowCursor))

    def mousePressEvent(self, event):
        if not self._is_enhanced_label() or self._processing_active():
            return
        if event.button() == Qt.LeftButton and self.original_pixmap:
            self.dragging = True
            self.last_pos = event.position()
            self.setCursor(QCursor(Qt.ClosedHandCursor))

    def mouseMoveEvent(self, event):
        if not (self._is_enhanced_label() and self.dragging and self.original_pixmap):
            return
        current_pos = event.position()
        delta = QPointF(current_pos - self.last_pos)
        self.last_pos = current_pos

        new_offset = self.offset + delta
        new_offset.setX(self._clamp_pan(new_offset.x(), self.original_pixmap.width(), self.width()))
        new_offset.setY(self._clamp_pan(new_offset.y(), self.original_pixmap.height(), self.height()))
        self.offset = new_offset

        # Keep the original view aligned with the enhanced view.
        other_label = self.main_window.original_label
        if other_label.original_pixmap:
            other_label.offset = new_offset
            other_label.update_display()
        self.update_display()

    def mouseReleaseEvent(self, event):
        if self._is_enhanced_label() and event.button() == Qt.LeftButton and self.dragging:
            self.dragging = False
            self._set_idle_cursor()

    def resizeEvent(self, event):
        self.update_display()
        super().resizeEvent(event)


# --- HELPER FUNCTION FOR EXPORT ---
def combine_frames_to_video(temp_dir, total_frames, fps, output_path,
                            error_signal_emitter, progress_callback=None):
    """Encode ``enhanced_frame_<i>.png`` files from *temp_dir* into a video.

    Args:
        temp_dir: directory that holds the frame images.
        total_frames: number of frames to write, starting at index 0.
        fps: frame rate of the output video.
        output_path: destination file; encoded with the ``mp4v`` fourcc.
        error_signal_emitter: callable taking one error-message string.
        progress_callback: optional ``callback(current, total)`` invoked
            before each frame is written.

    Returns:
        True on success.  False after reporting an error.  When frame 0 does
        not exist the result is ``total_frames == 0`` and nothing is written.
    """
    if not temp_dir or not os.path.isdir(temp_dir):
        error_signal_emitter(f"临时目录未找到: {temp_dir}")
        return False

    sample_frame_path = os.path.join(temp_dir, 'enhanced_frame_0.png')
    if not os.path.exists(sample_frame_path):
        return total_frames == 0
    sample_frame = cv2.imread(sample_frame_path)
    if sample_frame is None:
        error_signal_emitter("无法读取样本帧以确定视频大小。")
        return False

    height, width = sample_frame.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        error_signal_emitter(f"无法打开VideoWriter以输出: {output_path}")
        return False

    try:
        for frame_idx in range(total_frames):
            if progress_callback:
                progress_callback(frame_idx + 1, total_frames)
            frame_path = os.path.join(temp_dir, f"enhanced_frame_{frame_idx}.png")
            if not os.path.exists(frame_path):
                error_signal_emitter(f"缺少超分帧 {frame_idx}")
                return False
            frame = cv2.imread(frame_path)
            if frame is None:
                error_signal_emitter(f"损坏的超分帧 {frame_idx}")
                return False
            out.write(frame)
    finally:
        # Release the writer on every exit path, including exceptions.
        out.release()
    return True


# --- MainWindow Class ---
class MainWindow(QMainWindow):
    """Main window: video selection, frame-by-frame upscaling, preview, export.

    Processing is driven by a zero-interval ``QTimer`` on the GUI thread; each
    timeout upscales one frame and stores it as a PNG in a temporary directory.
    """

    export_error_signal = Signal(str)

    def __init__(self):
        super().__init__()
        self.setWindowTitle("视频超分 (单线程版)")
        self.setGeometry(100, 100, 1200, 700)

        self.video_cap = None
        self.current_frame_index = 0
        self.processed_frame_count = 0
        self.temp_output_dir = None
        self.processing_timer = QTimer(self)
        self.processing_timer.timeout.connect(self._process_next_frame)
        # Maps frame index -> RGB frame of the original video.
        self.frame_cache = {}

        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)
        layout.setContentsMargins(24, 24, 24, 24)
        layout.setSpacing(16)

        # --- Top row of control buttons ---
        control_layout = QHBoxLayout()
        control_layout.setSpacing(16)
        self.select_btn = QPushButton("选择视频")
        self.process_btn = QPushButton("超分视频")
        self.pause_btn = QPushButton("暂停")
        self.export_btn = QPushButton("导出视频")
        self.settings_btn = QPushButton("⚙设置")
        button_style = "QPushButton { background-color: #60A5FA; color: #FFFFFF; font-size: 16px; padding: 12px 24px; border-radius: 8px; } QPushButton:hover { background-color: #3B82F6; } QPushButton:disabled { background-color: #D1D5DB; color: #6B7280; }"
        for btn in [self.select_btn, self.process_btn, self.pause_btn,
                    self.export_btn, self.settings_btn]:
            btn.setFont(QFont("Arial", 16))
            btn.setStyleSheet(button_style)
            control_layout.addWidget(btn)
        self.process_btn.setEnabled(False)
        self.pause_btn.setEnabled(False)
        self.export_btn.setEnabled(False)
        layout.addLayout(control_layout)

        # --- Side-by-side original / enhanced views ---
        info_style = "QLabel { background-color: #FFFFFF; color: #000000; font-size: 14px; padding: 8px; border-radius: 8px; }"
        display_layout = QHBoxLayout()
        display_layout.setSpacing(16)

        original_col_layout = QVBoxLayout()
        original_col_layout.setSpacing(8)
        self.original_label = DraggableLabel(self, self)
        self.original_label.setText("原视频 - 选择视频")
        self.original_label.setAlignment(Qt.AlignCenter)
        self.original_label.setMinimumSize(480, 360)
        self.original_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        original_col_layout.addWidget(self.original_label)
        self.original_info_label = QLabel("原视频\n分辨率: N/A | FPS: N/A")
        self.original_info_label.setAlignment(Qt.AlignCenter)
        self.original_info_label.setStyleSheet(info_style)
        original_col_layout.addWidget(self.original_info_label)
        display_layout.addLayout(original_col_layout, stretch=1)

        enhanced_col_layout = QVBoxLayout()
        enhanced_col_layout.setSpacing(8)
        self.enhanced_label = DraggableLabel(self, self)
        self.enhanced_label.setText("超分视频 - 先处理或预览")
        self.enhanced_label.setAlignment(Qt.AlignCenter)
        self.enhanced_label.setMinimumSize(480, 360)
        self.enhanced_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        enhanced_col_layout.addWidget(self.enhanced_label)
        self.enhanced_info_label = QLabel("超分视频\n分辨率: N/A | FPS: N/A")
        self.enhanced_info_label.setAlignment(Qt.AlignCenter)
        self.enhanced_info_label.setStyleSheet(info_style)
        enhanced_col_layout.addWidget(self.enhanced_info_label)
        display_layout.addLayout(enhanced_col_layout, stretch=1)
        layout.addLayout(display_layout, stretch=1)

        # --- Progress bar and duration read-out ---
        progress_layout = QHBoxLayout()
        progress_layout.setSpacing(16)
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        self.progress_bar.setTextVisible(True)
        self.progress_bar.setStyleSheet(
            " QProgressBar { background-color: #E5E7EB; border-radius: 8px; height: 16px; text-align: center; color: #000000; font-size: 14px; }"
            " QProgressBar::chunk { background-color: #60A5FA; border-radius: 8px; }"
            " QProgressBar:disabled { background-color: #D1D5DB; color: #6B7280; }"
            " QProgressBar::chunk:disabled { background-color: #9CA3AF; } ")
        self.duration_label = QLabel("时长: 00:00 / 00:00")
        self.duration_label.setStyleSheet("QLabel { color: #000000; font-size: 14px; padding: 8px; }")
        progress_layout.addWidget(self.progress_bar)
        progress_layout.addWidget(self.duration_label)
        layout.addLayout(progress_layout)

        # --- Seek slider ---
        self.slider = QSlider(Qt.Horizontal)
        self.slider.setRange(0, 1000)
        self.slider.setEnabled(False)
        self.slider.setStyleSheet("QSlider::groove:horizontal { height: 8px; background: #E5E7EB; border-radius: 4px; } QSlider::handle:horizontal { background: #FFFFFF; border: 2px solid #60A5FA; width: 16px; height: 16px; margin: -6px 0; border-radius: 10px; } QSlider::sub-page:horizontal { background: #60A5FA; border-radius: 4px; }")
        layout.addWidget(self.slider)

        # --- Preview button ---
        preview_control_layout = QHBoxLayout(alignment=Qt.AlignCenter)
        self.preview_btn = QPushButton("预览滑块位置帧")
        self.preview_btn.setFont(QFont("Arial", 16))
        self.preview_btn.setStyleSheet(button_style)
        self.preview_btn.setEnabled(False)
        preview_control_layout.addWidget(self.preview_btn)
        layout.addLayout(preview_control_layout)

        # --- Status line ---
        self.status_label = QLabel("选择视频文件开始。", alignment=Qt.AlignCenter)
        self.status_label.setStyleSheet("QLabel { color: #000000; font-size: 16px; padding: 8px; }")
        layout.addWidget(self.status_label)

        # --- Signal wiring ---
        self.select_btn.clicked.connect(self.select_video)
        self.process_btn.clicked.connect(self.process_video)
        self.pause_btn.clicked.connect(self.toggle_pause)
        self.export_btn.clicked.connect(self.export_video)
        self.settings_btn.clicked.connect(self.open_settings)
        self.slider.valueChanged.connect(self.update_duration_label_only)
        self.slider.sliderReleased.connect(self.preview_frame_at_slider)
        self.preview_btn.clicked.connect(self.preview_frame_at_slider)
        self.export_error_signal.connect(lambda msg: QMessageBox.critical(self, "导出错误", msg))

        # --- Video / processing state ---
        self.video_path, self.output_path = None, None
        self.total_frames, self.video_duration_sec, self.video_fps = 0, 0, 0
        self.video_width, self.video_height, self.enhanced_width, self.enhanced_height = 0, 0, 0, 0
        self.is_processing, self.is_paused = False, False

        try:
            initialize_global_model()
        except Exception as e:
            QMessageBox.critical(self, "初始化错误", f"模型初始化失败: {e}")
            sys.exit(1)

    def _remove_temp_output_dir(self):
        """Delete the directory of processed frames, if any, and forget it."""
        if self.temp_output_dir and os.path.exists(self.temp_output_dir):
            shutil.rmtree(self.temp_output_dir, ignore_errors=True)
        self.temp_output_dir = None

    def _update_ui_for_processing_state(self):
        """Updates the enabled/disabled state of UI controls based on processing status."""
        is_processing_active = self.is_processing and not self.is_paused
        is_video_loaded = bool(self.video_path)

        # Controls disabled ONLY during ACTIVE processing (enabled when paused)
        self.progress_bar.setEnabled(not is_processing_active)
        self.slider.setEnabled(is_video_loaded and not is_processing_active)
        self.preview_btn.setEnabled(is_video_loaded and not is_processing_active)
        # MODIFICATION: These buttons are now also only disabled during ACTIVE processing.
        self.select_btn.setEnabled(not is_processing_active)
        self.settings_btn.setEnabled(not is_processing_active)

        # Controls disabled WHENEVER a process exists (active or paused)
        # Prevents starting a new process while one is paused.
        self.process_btn.setEnabled(is_video_loaded and not self.is_processing)

        # Pause button is special: only enabled when a process exists
        self.pause_btn.setEnabled(self.is_processing)

    def open_settings(self):
        """Show the modal settings dialog used to choose the model."""
        dialog = QDialog(self)
        dialog.setWindowTitle("设置")
        dialog.setMinimumWidth(400)
        layout = QVBoxLayout(dialog)
        settings_layout = QHBoxLayout()
        settings_layout.setSpacing(16)

        title_label = QLabel("模型选择")
        title_label.setStyleSheet("QLabel { font-size: 16px; font-weight: bold; color: #000000; }")
        title_label.setFixedWidth(100)
        settings_layout.addWidget(title_label)

        group = QGroupBox()
        group_layout = QVBoxLayout(group)
        group_layout.setContentsMargins(10, 10, 10, 10)
        performance_radio = QRadioButton("性能模式 (Real-ESRGAN-x4v3)")
        quality_radio = QRadioButton("高品质模式 (Real-ESRGAN-x4plus)")
        if CURRENT_MODEL_NAME == "real_esrgan_general_x4v3":
            performance_radio.setChecked(True)
        else:
            quality_radio.setChecked(True)
        group_layout.addWidget(performance_radio)
        group_layout.addWidget(quality_radio)
        settings_layout.addWidget(group)
        layout.addLayout(settings_layout)

        ok_btn = QPushButton("确定")
        ok_btn.setStyleSheet("QPushButton { background-color: #60A5FA; color: #FFFFFF; font-size: 16px; padding: 8px 16px; border-radius: 8px; } QPushButton:hover { background-color: #3B82F6; }")
        ok_btn.clicked.connect(lambda: self.update_model(performance_radio.isChecked(), dialog))
        layout.addWidget(ok_btn, alignment=Qt.AlignRight)
        dialog.exec()

    def update_model(self, is_performance, dialog):
        """Switch the global model according to the settings dialog choice.

        Args:
            is_performance: True selects ``real_esrgan_general_x4v3``, False
                selects ``real_esrgan_x4plus``.
            dialog: the settings dialog; closed unless loading fails.
        """
        global CURRENT_MODEL_NAME

        # If a process was paused, changing the model should stop it.
        if self.is_processing:
            self.stop_processing()

        new_model_name = "real_esrgan_general_x4v3" if is_performance else "real_esrgan_x4plus"
        # Also reload when no model is loaded (e.g. after a failed switch), so
        # re-selecting the same entry can recover.
        if new_model_name != CURRENT_MODEL_NAME or not model_initialized:
            self.frame_cache.clear()
            self._remove_temp_output_dir()
            # The processed frames were just deleted, so nothing can be exported.
            self.processed_frame_count = 0
            self.export_btn.setEnabled(False)
            self.enhanced_label.setText("超分视频 - 先处理或预览")
            self.enhanced_info_label.setText("超分视频\n分辨率: N/A | FPS: N/A")
            CURRENT_MODEL_NAME = new_model_name
            try:
                initialize_global_model()
                self.status_label.setText(f"已切换到{'性能模式' if is_performance else '高品质模式'}")
            except Exception as e:
                QMessageBox.critical(self, "模型切换错误", f"模型切换失败: {e}")
                self.status_label.setText("模型切换失败")
                return
        dialog.close()

    def select_video(self):
        """Ask for a video file, validate it and load its first frame.

        The current state is reset before the file dialog opens.  Videos whose
        properties cannot be read, or that are taller than 720 pixels, are
        rejected.
        """
        # If a process was paused, selecting a new video should stop it.
        if self.is_processing:
            self.stop_processing()
        self.reset_ui_state()

        file_path, _ = QFileDialog.getOpenFileName(self, "选择视频", "", "视频文件 (*.mp4 *.avi *.mov *.mkv)")
        if not file_path:
            return

        cap = cv2.VideoCapture(file_path)
        if not cap.isOpened():
            cap.release()
            self.status_label.setText("错误: 无法打开选择的视频。")
            return
        self.video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.video_fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        if not (self.video_height > 0 and self.video_width > 0
                and self.total_frames > 0 and self.video_fps > 0):
            # Reset first: reset_ui_state() rewrites the status line, so the
            # error text has to be set afterwards to stay visible.
            self.reset_ui_state()
            self.status_label.setText("错误: 无法读取视频属性。")
            return
        if self.video_height > 720:
            QMessageBox.warning(self, "分辨率过高", "请选择高度为720p或更低的视频。")
            self.reset_ui_state()
            return

        self.video_path = file_path
        self.video_duration_sec = self.total_frames / self.video_fps if self.video_fps > 0 else 0
        self.enhanced_width, self.enhanced_height = self.video_width * 4, self.video_height * 4
        self.slider.setValue(0)
        self.status_label.setText(f"已选择: {os.path.basename(file_path)}")
        self.original_info_label.setText(f"原视频\n分辨率: {self.video_width}x{self.video_height} | FPS: {self.video_fps:.2f}")
        self.enhanced_info_label.setText("超分视频\n分辨率: N/A | FPS: N/A")
        self.enhanced_label.setText("超分视频 - 先处理或预览")
        self.update_duration_label()
        self.show_specific_frame(0)
        self.process_btn.setEnabled(True)
        self.preview_btn.setEnabled(True)
        self.slider.setEnabled(True)

    def reset_ui_state(self):
        """Forget the loaded video, delete processed frames and reset widgets."""
        self.video_path = self.output_path = None
        self.total_frames = self.video_duration_sec = self.video_fps = 0
        self.video_width = self.video_height = self.enhanced_width = self.enhanced_height = 0
        self.is_processing = self.is_paused = False
        self.processed_frame_count = 0
        self.frame_cache.clear()
        self._remove_temp_output_dir()

        for label in [self.original_label, self.enhanced_label]:
            label.set_pixmap(QPixmap())
            label.offset = QPointF(0.0, 0.0)
        self.original_label.setText("原视频 - 选择视频")
        self.enhanced_label.setText("超分视频 - 先处理或预览")
        self.original_info_label.setText("原视频\n分辨率: N/A | FPS: N/A")
        self.enhanced_info_label.setText("超分视频\n分辨率: N/A | FPS: N/A")
        self.status_label.setText("选择视频文件开始。")
        self.progress_bar.setValue(0)
        self.slider.setValue(0)
        self.slider.setEnabled(False)
        self.duration_label.setText("时长: 00:00 / 00:00")
        for btn in [self.process_btn, self.preview_btn, self.pause_btn, self.export_btn]:
            btn.setEnabled(False)
        self.pause_btn.setText("暂停")
        self.progress_bar.setEnabled(True)

    def format_time(self, seconds):
        """Format a number of seconds as ``MM:SS`` (fraction truncated)."""
        mins, secs = divmod(int(seconds), 60)
        return f"{mins:02d}:{secs:02d}"

    def update_duration_label(self):
        """Show the slider position and total duration as ``MM:SS / MM:SS``."""
        current_pos_sec = (self.slider.value() / self.slider.maximum()) * self.video_duration_sec
        self.duration_label.setText(f"时长: {self.format_time(current_pos_sec)} / {self.format_time(self.video_duration_sec)}")

    def update_duration_label_only(self, value):
        """Slot for slider changes: refresh the duration text if a video is loaded."""
        if self.video_path:
            self.update_duration_label()

    def display_frame(self, frame_rgb, label_widget, current_offset=None, target_scale=1.0):
        """Show an RGB frame on *label_widget*.

        Args:
            frame_rgb: ``(h, w, 3)`` uint8 RGB array, or None / empty.
            label_widget: the ``DraggableLabel`` to draw on.
            current_offset: offset to paint at; the frame is centred when None.
            target_scale: factor applied to the pixmap before display.
        """
        if frame_rgb is None or frame_rgb.size == 0:
            label_widget.set_pixmap(QPixmap())
            label_widget.setText("无帧数据")
            return

        # QImage reads the buffer row by row using the stride given below,
        # so make sure the rows really are packed.
        frame_rgb = np.ascontiguousarray(frame_rgb)
        h, w, ch = frame_rgb.shape
        qimg = QImage(frame_rgb.data, w, h, ch * w, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qimg)
        if target_scale != 1.0:
            pixmap = pixmap.scaled(int(w * target_scale), int(h * target_scale),
                                   Qt.KeepAspectRatio, Qt.SmoothTransformation)

        if current_offset is None:
            offset = QPointF((label_widget.width() - pixmap.width()) / 2,
                             (label_widget.height() - pixmap.height()) / 2)
        else:
            offset = current_offset
        # set_pixmap() repaints the label itself.
        label_widget.set_pixmap(pixmap, new_offset=offset)

    def show_specific_frame(self, frame_index):
        """Display original frame *frame_index* and its processed version if present.

        The original frame is shown scaled by 4 so it lines up with the
        enhanced view.  Frames read from disk are stored in ``frame_cache``.
        """
        if not self.video_path or self.total_frames == 0:
            return
        frame_index = max(0, min(frame_index, self.total_frames - 1))
        current_offset = self.original_label.offset if self.original_label.offset != QPointF(0.0, 0.0) else None

        frame_rgb = self.frame_cache.get(frame_index)
        if frame_rgb is None:
            cap = cv2.VideoCapture(self.video_path)
            ret, frame = False, None
            if cap.isOpened():
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                ret, frame = cap.read()
            cap.release()
            if ret:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                self.frame_cache[frame_index] = frame_rgb.copy()
        if frame_rgb is not None:
            self.display_frame(frame_rgb, self.original_label, current_offset, target_scale=4.0)
        else:
            self.original_label.setText("加载帧错误")

        self.enhanced_label.offset = self.original_label.offset
        if self.temp_output_dir:
            frame_path = os.path.join(self.temp_output_dir, f"enhanced_frame_{frame_index}.png")
            if os.path.exists(frame_path):
                enhanced_bgr = cv2.imread(frame_path)
                # imread yields None for unreadable files; skip those.
                if enhanced_bgr is not None:
                    enh_rgb = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2RGB)
                    self.display_frame(enh_rgb, self.enhanced_label, self.original_label.offset)
                    self.enhanced_info_label.setText(f"超分视频 (已处理)\n分辨率: {self.enhanced_width}x{self.enhanced_height} | FPS: N/A")

        self.original_label.update_display()
        self.enhanced_label.update_display()

    def preview_frame_at_slider(self):
        """Show the frame at the slider position and upscale it for preview."""
        if not self.video_path:
            return
        frame_index = int((self.slider.value() / self.slider.maximum()) * (self.total_frames - 1))
        self.show_specific_frame(frame_index)

        original_frame_rgb = self.frame_cache.get(frame_index)
        if original_frame_rgb is None:
            return
        self.enhanced_label.setText("超分预览中...")
        QApplication.processEvents()
        try:
            enhanced_frame_rgb = Inference(original_frame_rgb)
            self.display_frame(enhanced_frame_rgb, self.enhanced_label, self.original_label.offset)
            self.enhanced_info_label.setText(f"超分视频 (预览)\n分辨率: {self.enhanced_width}x{self.enhanced_height} | FPS: N/A")
        except Exception as e:
            self.enhanced_label.setText("预览超分失败")
            self.enhanced_info_label.setText("超分视频\n分辨率: 失败 | FPS: N/A")
            print(f"预览错误: {e}")
        QApplication.processEvents()

    def process_video(self):
        """Start upscaling the loaded video from its first frame."""
        if not self.video_path or self.is_processing:
            return
        self.stop_processing()

        self.video_cap = cv2.VideoCapture(self.video_path)
        if not self.video_cap.isOpened():
            self.processing_error(f"无法打开视频文件: {self.video_path}")
            return

        # Remove frames of a previous run so repeated runs do not leave
        # temporary directories behind; export is unavailable until this
        # run finishes.
        self._remove_temp_output_dir()
        self.export_btn.setEnabled(False)
        self.temp_output_dir = tempfile.mkdtemp()

        self.current_frame_index = 0
        self.processed_frame_count = 0
        self.is_processing = True
        self.is_paused = False
        self.pause_btn.setText("暂停")
        self.status_label.setText("处理中...")
        self.enhanced_label.setText("处理中...")
        self.progress_bar.setValue(0)
        self._update_ui_for_processing_state()
        self.processing_timer.start(0)

    def _process_next_frame(self):
        """Timer slot: read, upscale and save one frame, then update progress."""
        if not self.is_processing or self.is_paused:
            return
        if self.current_frame_index >= self.total_frames:
            self.processing_finished()
            return

        try:
            ret, frame = self.video_cap.read()
            if not ret:
                self.processing_finished()
                return

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            enhanced_frame_rgb = Inference(frame_rgb)

            frame_path = os.path.join(self.temp_output_dir, f"enhanced_frame_{self.current_frame_index}.png")
            frame_bgr = cv2.cvtColor(enhanced_frame_rgb, cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(frame_path, frame_bgr):
                self.processing_error(f"保存帧失败 {self.current_frame_index}: 写入错误")
                return

            self.current_frame_index += 1
            progress = int((self.current_frame_index / self.total_frames) * 100)
            self.update_progress(progress)
            QApplication.processEvents()
        except Exception as e:
            traceback.print_exc()
            self.processing_error(f"处理帧 {self.current_frame_index} 时发生异常: {e}")

    def toggle_pause(self):
        """Pause or resume the running processing job."""
        if not self.is_processing:
            return
        self.is_paused = not self.is_paused
        if self.is_paused:
            # Stop the zero-interval timer so it does not keep firing as a
            # no-op for the whole duration of the pause.
            self.processing_timer.stop()
            self.pause_btn.setText("恢复")
            self.status_label.setText("已暂停")
        else:
            self.pause_btn.setText("暂停")
            self.status_label.setText(f"处理中... {self.progress_bar.value()}%")
            self.processing_timer.start(0)
        self._update_ui_for_processing_state()

    def stop_processing(self):
        """Stop the timer, release the capture and return the UI to idle."""
        self.processing_timer.stop()
        if self.video_cap:
            self.video_cap.release()
            self.video_cap = None
        self.is_processing = self.is_paused = False
        self.pause_btn.setText("暂停")
        self._update_ui_for_processing_state()
        self.progress_bar.setEnabled(True)

    def update_progress(self, value):
        """Set the progress bar to *value* percent and mirror it in the status line."""
        self.progress_bar.setValue(value)
        if not self.is_paused:
            self.status_label.setText(f"处理中... {value}%")

    def processing_error(self, error_message):
        """Abort processing and report *error_message* to the user."""
        # Stop first so the timer cannot process further frames while the
        # modal dialog is open.
        self.stop_processing()
        QMessageBox.critical(self, "处理错误", f"发生错误:\n{error_message}")
        self.status_label.setText(f"处理失败: {error_message}")
        self.enhanced_label.setText("处理失败")
        self.progress_bar.setValue(0)

    def processing_finished(self):
        """Finish a run: record the frame count, enable export, show frame 0."""
        self.processed_frame_count = self.current_frame_index
        self.stop_processing()
        self.status_label.setText("处理完成。准备导出。")
        self.enhanced_info_label.setText(f"超分视频\n分辨率: {self.enhanced_width}x{self.enhanced_height} | FPS: N/A")
        self.progress_bar.setValue(100)
        self.export_btn.setEnabled(True)
        self.slider.setValue(0)
        self.show_specific_frame(0)

    def export_video(self):
        """Ask for a destination and write the processed frames as an MP4."""
        if not self.temp_output_dir or not os.path.isdir(self.temp_output_dir):
            QMessageBox.warning(self, "导出错误", "未找到已处理帧以导出。请先处理视频。")
            return

        base, _ = os.path.splitext(os.path.basename(self.video_path))
        default_filename = f"{base}_enhanced.mp4"
        destination_path, _ = QFileDialog.getSaveFileName(self, "导出超分视频", default_filename, "MP4 视频文件 (*.mp4)")
        if not destination_path:
            return

        self.status_label.setText("准备导出...")
        self.progress_bar.setValue(0)
        QApplication.processEvents()

        def export_progress(current, total):
            # Keep the UI responsive while frames are encoded on this thread.
            if total > 0:
                progress = int((current / total) * 100)
                self.progress_bar.setValue(progress)
                self.status_label.setText(f"正在导出帧: {current} / {total}")
                QApplication.processEvents()

        frames_to_export = self.processed_frame_count
        success = combine_frames_to_video(self.temp_output_dir, frames_to_export, self.video_fps,
                                          destination_path, self.export_error_signal.emit,
                                          progress_callback=export_progress)
        if success:
            self.progress_bar.setValue(100)
            self.status_label.setText("视频导出成功!")
            QMessageBox.information(self, "导出成功", f"超分视频已保存到:\n{destination_path}")
        else:
            self.status_label.setText("导出失败。")
            self.progress_bar.setValue(0)

    def closeEvent(self, event):
        """Stop processing, delete temporary frames and release the model."""
        self.stop_processing()
        if self.temp_output_dir and os.path.exists(self.temp_output_dir):
            shutil.rmtree(self.temp_output_dir, ignore_errors=True)
        release_global_model()
        event.accept()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    app.setStyleSheet("QMainWindow { background-color: #FFFFFF; }")
    window = MainWindow()
    window.showMaximized()
    sys.exit(app.exec())