import numpy as np
import librosa
from scipy.signal import butter, sawtooth, sosfilt
import sys
import math
import threading
import multiprocessing as mp
from queue import Queue

def bandpass_filter(data, lowcut, highcut, fsinput, order=6):
    """Band-pass ``data`` with a Butterworth filter in second-order sections.

    Args:
        data: signal array; filtered along its last axis.
        lowcut: lower corner frequency in Hz.
        highcut: upper corner frequency in Hz.
        fsinput: sample rate of ``data`` in Hz.
        order: Butterworth order.

    Returns:
        The filtered array, same shape as ``data`` (causal, single-pass ``sosfilt``).
    """
    sections = butter(
        N=order,
        Wn=[lowcut, highcut],
        btype='bandpass',
        output='sos',
        fs=fsinput,
    )
    result = sosfilt(sections, data, axis=-1)
    return result

def lowpass_filter(data, cutoff, fsinput, order=4):
    """Low-pass ``data`` with a Butterworth filter in second-order sections.

    Args:
        data: signal array; filtered along its last axis.
        cutoff: corner frequency in Hz.
        fsinput: sample rate of ``data`` in Hz.
        order: Butterworth order.

    Returns:
        The filtered array, same shape as ``data`` (causal, single-pass ``sosfilt``).
    """
    sections = butter(N=order, Wn=cutoff, btype='lowpass', output='sos', fs=fsinput)
    result = sosfilt(sections, data, axis=-1)
    return result

def lightblade(finput, bladeshape, sr, vatype):
    """Turn a signal into light-blade intensity.

    For variable-area rendering (``vatype == True``) the signal is compared
    against ``bladeshape`` with gain ``sr``: ``sr * (finput - bladeshape)``.
    Otherwise the signal is returned unchanged.
    """
    is_variable_area = vatype == True  # noqa: E712 - keep the original equality test
    if not is_variable_area:
        return finput
    return sr * (finput - bladeshape)

def generate_optical_track(input_wav, output_bin, height, width, fps, num_threads=None,slewrate=20,vatype=True):
    """Render an audio file as a simulated optical soundtrack in raw 8-bit gray video.

    The audio is normalised with one factor for all channels, band-limited to
    what the track can carry, biased by its own envelope and rescaled, then
    rendered frame by frame by a pool of worker threads. Each worker writes
    its frame directly at ``frame_idx * frame_bytes`` in a pre-sized file.

    Args:
        input_wav: path of any audio file ``librosa.load`` can read.
        output_bin: path of the raw output (one uint8 per pixel).
        height: rows per frame.
        width: pixel width of each channel's track; channels sit side by side,
            so a frame is ``width * channel`` pixels wide.
        fps: video frame rate.
        num_threads: worker thread count; defaults to ``mp.cpu_count()``.
            Values below 1 are raised to 1.
        slewrate: comparator slew rate forwarded to the workers; smaller
            values give blurrier track edges.
        vatype: True for variable-area, False for variable-density rendering.
    """
    # Default to the number of CPU cores; never run with zero workers, which
    # would leave the queued frames unprocessed.
    if num_threads is None:
        num_threads = mp.cpu_count()
    num_threads = max(1, num_threads)

    # Load audio, keeping every channel.
    data, fs = librosa.load(input_wav, sr=None, mono=False)
    # Always work with a (channel, samples) array, including for mono input.
    if data.ndim == 1:
        data = data[np.newaxis, :]
    channel = data.shape[0]

    print(f"{num_threads}个线程")

    # Trim 1% off each end when measuring peaks. Presumably this keeps
    # start/end transients out of the scaling.
    rt = .98
    n_samples = data.shape[1]
    rl = math.floor(n_samples*((1-rt)/2))
    rr = math.floor(n_samples*(1-((1-rt)/2)))
    if rr <= rl:
        # Clip too short to trim: measure over the whole signal instead of an
        # empty window (np.min/np.max would raise).
        rl, rr = 0, n_samples
    print(rl,rr)

    # Normalise all channels by the same factor to preserve the stereo image.
    # Guard against silent input, which would otherwise divide by zero.
    peak = np.max(np.abs(data))
    if peak > 0:
        data = data / peak

    # Band-limit to what the track can carry, clamped below Nyquist.
    highcut = (height * fps) / 4
    if highcut >= fs / 2:
        highcut = (fs / 2) - 1

    filtered = bandpass_filter(data, 20, highcut, fs)

    ct = filtered[:,rl:rr]
    ftmin = np.min(ct)
    ftmax = np.max(ct)
    print(ftmin,ftmax)
    scale = max(ftmax, -ftmin)
    if scale > 0:
        filtered = filtered / scale

    # Add the waveform's own envelope so the alternating waveform becomes a
    # positive-going pulse, minimising the light-blade width.
    #
    # Envelope source: `data` or `filtered`. Using `data` makes content above
    # the track's bandwidth show up as a flat wide line; `filtered` has that
    # content removed, so the wide line disappears.
    envelopesource=filtered

    envelope = lowpass_filter(np.abs(envelopesource), 70, fs)

    ct = envelope[:,rl:rr]
    evmin = np.min(ct)
    evmax = np.max(ct)
    print(evmin,evmax)

    span = evmax - evmin
    # A flat envelope (e.g. silence) would give 0/0; treat it as no bias.
    envelope = (envelope-evmin)/span if span > 0 else np.zeros_like(envelope)

    filtered = (filtered + envelope)*.5

    # Adjust light-blade width (lower/upper signal bounds).
    rrmin=-.1
    rrmax=.85
    filtered = (filtered-rrmin)/(rrmax-rrmin)

    # Triangle-wave template for one row, tiled to cover a whole frame
    # (height rows x channel segments).
    t_line = np.linspace(0, 1, width, endpoint=False)
    triangle_line = (sawtooth(2 * np.pi * t_line + np.pi, 0.5) + 1) / 2
    triangle_frame_template = np.tile(triangle_line, height * channel)

    # Each frame is `height` rows; every row holds one `width`-pixel segment
    # per channel, side by side. Each channel contributes height*width
    # samples per frame.
    samples_per_frame = height * width
    total_time = n_samples / fs
    n_frames = math.ceil(total_time * fps)

    # Pre-size the output file so workers can fill frames in any order.
    frame_bytes = height * width * channel
    total_bytes = n_frames*frame_bytes
    print(n_frames,'*',frame_bytes,'~=',round(total_bytes/1048576,1),'MB')
    with open(output_bin, 'wb') as f:
        f.truncate(total_bytes)

    frame_queue = Queue()

    # Start the worker threads.
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(
            target=multichannel_frame_processing_worker,
            args=(filtered,fs,fps,samples_per_frame,channel,triangle_frame_template,frame_queue,output_bin,height,width,slewrate,vatype),
            daemon=True,
        )
        t.start()
        threads.append(t)

    # Queue every frame index for the workers.
    for frame_idx in range(n_frames):
        print(frame_idx,"/",n_frames)
        frame_queue.put(frame_idx)

    # One stop sentinel per worker.
    for _ in range(num_threads):
        frame_queue.put(None)

    # Wait on the threads themselves, not on frame_queue.join().
    # - Queue.join() returns once the last task_done() runs, and a worker
    #   acknowledges its sentinel before it leaves its `with open(...)`.
    #   Its buffered file handle may therefore not be flushed/closed yet, and
    #   since the workers are daemon threads, the process could exit first.
    # - Joining threads also cannot hang forever if a worker dies mid-frame.
    for t in threads:
        t.join()

    print(f"{channel}个声道，输出视频：高度{height}、宽度{width*channel}、帧率{fps}")
    print(f"ffmpeg命令参考：-f rawvideo -pix_fmt gray -s {width*channel}x{height} -r {fps}")

def multichannel_frame_processing_worker(filtered,fs,fps,samples_per_frame,channel,triangle_frame_template,frame_queue,output_bin,height,width,slewrate,vatype):
    """Render frames taken from ``frame_queue`` and write them into ``output_bin``.

    Runs until it receives ``None``. For each frame index it:
    - linearly interpolates ``samples_per_frame`` points per channel from
      ``filtered`` at positions ``(frame_idx + k / samples_per_frame) * fs / fps``;
      points outside the signal stay 0;
    - lays the channels side by side on each of the ``height`` rows;
    - applies ``lightblade`` against ``triangle_frame_template``, clips to
      [0, 1] and scales to uint8;
    - writes the bytes at ``frame_idx * frame_bytes``.

    ``output_bin`` must already exist and be large enough, since it is opened
    with ``'r+b'``. Each frame is flushed before ``task_done()`` is reported,
    and ``task_done()`` is called even if rendering fails, so queue
    accounting stays balanced.
    """
    fs_fps_ratio = fs / fps
    frame_bytes = height * width * channel
    n_samples = len(filtered[0])
    # Fractional position of each pixel inside a frame; the same for every frame.
    offsets = np.arange(samples_per_frame) / samples_per_frame

    with open(output_bin,'r+b') as fout:
        while True:
            frame_idx = frame_queue.get()

            if frame_idx is None:
                frame_queue.task_done()
                break

            try:
                pos = (frame_idx + offsets) * fs_fps_ratio
                frame_resampled = np.zeros((channel, samples_per_frame))
                # Need idx0 + 1 to exist for interpolation.
                mask_valid = (pos >= 0) & (pos < n_samples - 1)
                valid_pos = pos[mask_valid]

                if valid_pos.size > 0:
                    idx0 = valid_pos.astype(int)
                    idx1 = idx0 + 1
                    frac = valid_pos - idx0

                    for ch in range(channel):
                        chan_data = filtered[ch]
                        frame_resampled[ch, mask_valid] = (
                            (1 - frac) * chan_data[idx0]
                            +
                            frac * chan_data[idx1]
                        )

                # (channel, height, width) -> (height, channel, width): each
                # output row carries one width-pixel segment per channel.
                reshaped = frame_resampled.reshape((channel, height, width))
                frame_resampled_interleaved = reshaped.transpose(1, 0, 2).reshape(-1)

                output_frame = lightblade(
                    frame_resampled_interleaved,
                    triangle_frame_template,
                    slewrate,
                    vatype
                )

                output_frame = np.clip(output_frame, 0, 1)
                binary_frame = (output_frame * 255).astype(np.uint8)

                # Write directly at this frame's slot in the pre-sized file.
                fout.seek(frame_idx * frame_bytes)
                fout.write(binary_frame.tobytes())
                # Flush before reporting completion. A caller waiting on
                # frame_queue.join() must not see this frame as done while
                # part of it is still in this thread's write buffer.
                fout.flush()
            finally:
                frame_queue.task_done()

if __name__ == '__main__':
    # Too few arguments: show usage and exit with status 1.
    if not len(sys.argv) >= 6:
        for usage_line in (
            "光学音轨仿真程序",
            "使用方法：python opttkgen.py 输入文件名（使用librosa库，不局限于wav格式） 输出文件名（RAWVIDEO，Gray（uint8）） 每帧图像高度 每声道宽度 视频帧率 [线程数]",
            "线程数不输入时默认为CPU线程数",
        ):
            print(usage_line)
        sys.exit(1)

    # Positional arguments: input audio, output raw file, frame height,
    # per-channel width, frame rate.
    input_wav, output_bin = sys.argv[1], sys.argv[2]
    height = int(sys.argv[3])
    width = int(sys.argv[4])
    fps = int(sys.argv[5])

    # Optional 6th argument: thread count. None lets the generator use the
    # CPU core count.
    num_threads = int(sys.argv[6]) if len(sys.argv) > 6 else None

    # Edge smoothing: the comparator's slew rate. Smaller values give blurrier
    # track edges. About 20 is suggested for a per-channel width of 64; wider
    # tracks should use larger values.
    slewrate = 20

    # Track type: True for variable-area, False for variable-density.
    VA = True

    print("光学音轨仿真程序")
    print("作者：TCJY（人类）、DeepSeek、ChatGPT、Gemini")

    generate_optical_track(input_wav, output_bin, height, width, fps, num_threads, slewrate, VA)
    print(f"输出已写入 {output_bin}。")