# Predict.jl

export predict
export get_images_from_audio

using WAV,
    DSP, Images, ThreadsX, Dates, DataFrames, CSV, Flux, CUDA, Metalhead, JLD2, JSON, FLAC, Glob, PerceptualColourMaps
import Base: length, getindex

## Dependency, duplicated from Utility
# Resample `signal` (time along dimension 1) from `freq` Hz to 16 kHz.
# Returns the tuple `(resampled_signal, 16000)`.
_resample_to_16000hz(signal, freq) =
    (DSP.resample(signal, 16000.0f0 / freq; dims=1), 16000)

# Resample `signal` (time along dimension 1) from `freq` Hz to 8 kHz.
# Returns the tuple `(resampled_signal, 8000)`.
_resample_to_8000hz(signal, freq) =
    (DSP.resample(signal, 8000.0f0 / freq; dims=1), 8000)

## Dependency, duplicated from Clips
# Render an audio sample as a spectrogram image for the model.
#
# `sample` is a vector of audio samples and `f` its sampling rate (converted to Int
# for `DSP.spectrogram`). Pipeline:
#   1. power spectrogram, 400-sample window, 2-sample overlap;
#   2. zero powers replaced by the smallest positive power so `pow2db` stays finite;
#   3. converted to dB, shifted up by abs(minimum) and divided by the maximum;
#   4. rows (frequency axis) reversed;
#   5. coloured with the "L4" perceptual colour map, resized to 224x224, Float32.
#
# A sample with no positive power at all (e.g. digital silence) returns an all-zero
# 224x224x3 Float32 image instead of failing with a BoundsError.
#
# NOTE: this must stay identical to the preprocessing used to train the model.
function _get_image_from_sample(sample, f) #sample::Vector{Float64}
    S = DSP.spectrogram(sample, 400, 2; fs=convert(Int, f))
    i = S.power
    if minimum(i) == 0.0
        positive = Iterators.filter(>(0), i)
        # Nothing to substitute: pow2db would give -Inf everywhere and the
        # normalisation below would produce NaNs, so render silence as blank.
        isempty(positive) && return zeros(Float32, 224, 224, 3)
        # Power values are non-negative, so this equals the second-smallest unique
        # value, without sorting the whole matrix.
        replace!(i, 0.0 => minimum(positive))
    end
    image =
        #! format: off
        DSP.pow2db.(i) |>
        x -> x .+ abs(minimum(x)) |>
        x -> x ./ maximum(x) |>
        x -> reverse(x, dims = 1) |>
        x -> PerceptualColourMaps.applycolourmap(x, cmap("L4")) |>
        x -> imresize(x, 224, 224) |>
        x -> Float32.(x)
        #! format: on
    return image
end

"""
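    predict(glob_pattern::String, model::String)
    predict(folders::Vector{String}, model::String)
    predict(folders::Vector{String}, model::String, labels::Dict)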

Run a saved model over every folder matched by a glob pattern (or over an explicit vector of folders). Predictions are saved as a CSV file in each folder (preds-YYYY-MM-DD.csv, dated today), similar to OpenSoundscape.

Args:

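•  glob pattern (folder/) or a vector of folders
•  model path (a JLD2 file containing "model_state")
•  labels (optional): a Dict from predicted class index to label name. When given, folders containing AviaNZ .data files have the Kiwi segments in those files relabelled in place instead of producing a CSV.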

Returns: nothing. Results are written to CSV files (or, when labels are given, back into the AviaNZ .data files).

I use this function to find kiwi from new data gathered on a trip. And to predict D/F/M/N for images clipped from primary detections.

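It works on both audio (wav or flac) and png images. If a folder contains any png images, only the images are used; otherwise all wav and flac files in it are processed.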

Note:
Run from Pomona-3/Pomona-3/ with:
julia +1.10 -t 4
Don't forget the temporary environment: ] activate --temp

Use like:
using SkraakML, Glob

glob_pattern = "*/*/"
model = "/media/david/SSD2/PrimaryDataset/model_K1-9_original_set_CPU_epoch-7-0.9924-2024-03-05.jld2"

glob_pattern = "Clips_2024-10-21/"
model = "/media/david/SSD1/Clips/model_DFMN1-5_CPU_epoch-18-0.9132-2024-01-29.jld2"

glob_pattern = "Clips_Kahurangi_2024-10-25"
glob_pattern = "Clips_Cobb_2024-10-25/"
glob_pattern = "Clips_2024-10-28_MT6"
glob_pattern = "Pomona-*/Pomona/Clips_2024-11-27/"
model = "/media/david/SSD2/Secondary_Models/DFMN_Inge/model_DFMN1-5_CPU_epoch-9-0.9737-2024-10-25.jld2"
model = "/media/david/SSD2/DFMN_Pomona/model_DFMN1-5_Pomona2_CPU_epoch-2-0.9459-2025-01-19.jld2"
model = "/media/david/SSD2/Secondary_Models/LSK/model_GSK_LSK_DFM_FT_IngeDFMN_1-5_1-0_CPU_epoch-9-0.9745-2025-01-13.jld2"
model = "/media/david/SSD2/Secondary_Models/DFMN_Pomona/model_DFMN1-5_Pomona3_CPU_epoch-18-0.9785-2025-03-02.jld2"
predict(glob_pattern, model)
folders=glob("NEW*/*/Clips_*/Kiwi/")
predict(folders, model)
"""

function predict(glob_pattern::String, model::String)
    # Expand the pattern, then defer to the folder-list method. That method loads
    # the model once, logs the folder list and processes each folder in turn.
    return predict(Glob.glob(glob_pattern), model)
end

# Run the model stored at path `model` over each folder in `folders`.
# No label map is supplied, so AviaNZ `.data` files are not used and CSVs are written.
function predict(folders::Vector{String}, model::String)
    # An empty label map is exactly what `predict_folder` uses by default.
    return predict(folders, model, Dict())
end

# Run the model stored at path `model` over each folder in `folders`, passing the
# class-index => name map `labels` through to `predict_folder`.
function predict(folders::Vector{String}, model::String, labels::Dict)
    net = device(load_model_pred(model))
    @info "Folders: $folders"
    foreach(folders) do folder
        @info "Working on: $folder"
        predict_folder(folder, net, labels)
    end
end

#~~~~~ The guts ~~~~~#
# see load_model() from train, different input types
# Rebuild the prediction network and load the weights saved in a JLD2 file.
#
# The file must contain a "model_state" entry. The network is the first stage of
# Metalhead's ResNet-18 (`layers[1]`, built with pretrain=false), followed by the head
# AdaptiveMeanPool((1, 1)) -> flatten -> Dense(512 => n_classes).
# The model is returned on the CPU; callers move it with `device`.
function load_model_pred(model_path::String)
    state = JLD2.load(model_path, "model_state")
    # The class count is the length of state[1][2][1][3][2]. This is presumably
    # layers -> head -> layers -> Dense -> bias; TODO confirm against a saved state.
    n_classes = length(state[1][2][1][3][2])
    @info "Model classes: $n_classes"
    backbone = Metalhead.ResNet(18, pretrain=false).layers[1]
    head = Flux.Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => n_classes))
    model = Flux.Chain(backbone, head)
    Flux.loadmodel!(model, state)
    return model
end

#=
function load_bson(model_path::String)
    BSON.@load model_path model
end
=#

# Predict on a single folder, choosing the input type found there:
#   1. AviaNZ `.data` files, when present AND `labels` is non-empty
#      (predictions are written back into the `.data` files);
#   2. otherwise `.png` images, if any (audio is then ignored);
#   3. otherwise `.wav` (any letter case) and `.flac` audio files.
# Logs a message and does nothing when none of these are present.
function predict_folder(folder::String, model, labels::Dict=Dict())
    data_files = Glob.glob("$folder/*.data")
    if !isempty(data_files) && !isempty(labels)
        predict_avianz_folder(data_files, model, folder, labels)
        return
    end
    # Character classes list letters only: a comma inside `[...]` is matched literally.
    wav = Glob.glob("$folder/*.[Ww][Aa][Vv]")
    flac = Glob.glob("$folder/*.flac")
    audio_files = vcat(wav, flac) # if wav and flac both present will predict on all
    png_files = Glob.glob("$folder/*.png")
    # Images take precedence over audio when both are present.
    if !isempty(png_files)
        predict_image_folder(png_files, model, folder)
    elseif !isempty(audio_files)
        predict_audio_folder(audio_files, model, folder)
    else
        @info "No png, flac, wav, WAV files present in $folder"
    end
end

# Device used for the model and input batches: `gpu` when CUDA is usable, else `cpu`.
# NOTE(review): this is evaluated once, when the file is loaded (at precompile time if
# it is part of a precompiled package). Confirm that is intended.
device = if CUDA.functional()
    gpu
else
    cpu
end

# Predict from png images
"""
    PredictImageContainer(img::Vector)

Indexable wrapper around a vector of image file paths. It is the data source handed to
`Flux.DataLoader` when predicting on PNG images.
"""
struct PredictImageContainer{T<:Vector}
    img::T
end

# Number of image paths in the container (used by `Flux.DataLoader` to size batches).
function length(data::PredictImageContainer)
    return length(data.img)
end

# Load the `idx`-th image and return `(image_array, path)`.
# The image is resized to 224x224, converted to RGB then Float32, and laid out as
# width x height x channel (channelview gives 3 x H x W; permuting (3, 2, 1) reverses it).
function getindex(data::PredictImageContainer{Vector{String}}, idx::Int)
    path = data.img[idx]
    rgb = Images.RGB.(Images.imresize(Images.load(path), 224, 224))
    chw = collect(channelview(float32.(rgb)))
    img = permutedims(chw, (3, 2, 1))
    return img, path
end

# Predict on the PNG images of one folder and write `<folder>/preds-<today>.csv`
# with columns `file` (last path component) and `label` (`Flux.onecold` class index).
function predict_image_folder(png_files::Vector{String}, model, folder::String)
    n_files = length(png_files)
    @assert (n_files > 0) "No png files present in $folder"
    @info "$(n_files) png_files in $folder"
    save_path = "$folder/preds-$(today()).csv"
    loader = png_loader(png_files)
    @time preds, files = predict_pngs(model, loader)
    names = map(p -> last(split(p, "/")), files)
    CSV.write(save_path, DataFrames.DataFrame(file=names, label=preds))
end

# Build a DataLoader over PNG paths. Each batch is `(images, paths)`: the images
# collated from `PredictImageContainer` and the matching file paths.
# Batches stay on the CPU. `predict_pngs` moves only the image tensor to `device`,
# because the paths (Strings) cannot be stored in a GPU array. Wrapping the loader in
# `CuIterator` would try to move them too.
function png_loader(png_files::Vector{String})
    return Flux.DataLoader(
        PredictImageContainer(png_files);
        batchsize=64,
        collate=true,
        parallel=true,
    )
end

# Run model `m` over loader `d` and return `(pred, path)`: the `Flux.onecold` class
# index for every image and the path it came from, in loader order.
function predict_pngs(m, d)
    @info "Predicting..."
    pred = Int[]
    path = String[]
    for (x, pth) in d
        # Move the batch to the model's device and bring the scores back to the CPU
        # before decoding, so `pred` always holds plain Ints (a no-op on the CPU).
        p = Flux.onecold(cpu(m(device(x))))
        append!(pred, p)
        append!(path, pth)
    end
    return pred, path
end

# Predict from audio files
# Predict on every audio file of one folder, writing `<folder>/preds-<today>.csv`.
# The header is written once from an empty, typed frame. Each file's rows are then
# appended as soon as that file is done, so partial results survive a crash.
function predict_audio_folder(audio_files::Vector{String}, model, folder::String)
    n_files = length(audio_files)
    @assert (n_files > 0) "No wav or flac audio files present in $folder"
    @info "$(n_files) audio_files in $folder"
    header = DataFrames.DataFrame(
        file=String[],
        start_time=Float64[],
        end_time=Float64[],
        label=Int[],
    )
    save_path = "$folder/preds-$(today()).csv"
    CSV.write(save_path, header)
    foreach(audio_files) do file
        CSV.write(save_path, predict_audio_file(file, model); append=true)
    end
end

# Predict on one audio file. Returns a DataFrame with columns `file`, `start_time`,
# `end_time` (seconds from the start of the file) and `label` (`Flux.onecold` class
# index), sorted by all columns.
function predict_audio_file(file::String, model)
    #check form of opensoundscape preds.csv and needed by my make_clips
    @info "File: $file"
    @time data = audio_loader(file)
    pred = Int[]
    clip_times = Tuple{Float64,Float64}[]
    @time for (x, t) in data
        # The loader yields CPU batches: move only the images to the model's device
        # and bring the scores back before decoding (a no-op on the CPU).
        p = Flux.onecold(cpu(model(device(x))))
        append!(pred, p)
        append!(clip_times, t)
    end
    df = DataFrames.DataFrame(
        :file => fill(file, length(clip_times)),
        :start_time => first.(clip_times),
        :end_time => last.(clip_times),
        :label => pred,
    )
    sort!(df)
    return df
end

# Build a single-batch DataLoader over all clips of `file`. The batch is
# `(images, time)`, where `time` holds one `(start, end)` pair in seconds per clip.
# Batches are left on the CPU. The time pairs are not image data, so the caller moves
# only the images to `device`.
function audio_loader(file::String, increment::Int=5, divisor::Int=2)
    raw_images, n_samples = get_images_from_audio(file, increment, divisor)
    images = reshape_images(raw_images, n_samples)

    # `get_images_from_audio` splits the signal with zero overlap (it passes 0 as
    # `DSP.arraysplit`'s overlap argument). Clip k (1-based) therefore spans
    # [(k-1)*increment, k*increment) seconds. Stepping by increment/divisor would
    # mislabel every clip after the first.
    clip_seconds = float(increment)
    start_time = clip_seconds .* (0:n_samples-1)
    end_time = start_time .+ clip_seconds
    time = collect(zip(start_time, end_time))

    # The batch size must be positive; with no clips the loader simply yields nothing.
    return Flux.DataLoader((images, time); batchsize=max(n_samples, 1), shuffle=false)
end

# Stack per-clip images into one batch with the clip index as the last dimension.
#
# `raw_images` holds `n_samples` equally sized 3-D images (224x224x3 for
# `get_images_from_audio`). The result has size `(size(img)..., n_samples)`, and image
# k is in `images[:, :, :, k]`. Each image is copied into its own slice. Using
# `hcat(raw_images...)` followed by a reshape interleaves channels of different clips
# when n_samples > 1, because `hcat` joins 3-D arrays along dimension 2.
#
# NOTE(review): the PNG path in this file feeds images as W x H x C (it permutes the
# channelview by (3, 2, 1)). These images keep the layout produced by
# `_get_image_from_sample`. Confirm which layout the model was trained on.
function reshape_images(raw_images, n_samples)
    length(raw_images) == n_samples || throw(
        DimensionMismatch("expected $n_samples images, got $(length(raw_images))"),
    )
    # No clips (e.g. audio shorter than one clip): return an empty batch.
    n_samples == 0 && return Array{Float32}(undef, 224, 224, 3, 0)
    first_img = first(raw_images)
    images = similar(first_img, size(first_img)..., n_samples)
    for (k, img) in enumerate(raw_images)
        selectdim(images, ndims(images), k) .= img
    end
    return images
end

#= not needed
function get_image_for_inference(sample, f)
    image =
        #! format: off
        _get_image_from_sample(sample, f) |>
        # x -> collect(channelview(float32.(x))) |>
        x -> permutedims(x, (3, 2, 1))
        #! format: on
    return image
end
=#

# Cut an audio file into consecutive `increment`-second clips and render each clip
# as a spectrogram image with `_get_image_from_sample`.
#
# Audio sampled above 16 kHz is first resampled to 16 kHz; lower rates are used as
# they are. Only channel 1 is used. The third argument of `DSP.arraysplit` is the
# overlap in samples, and it is 0 here, so clips do not overlap. A trailing partial
# clip is presumably dropped (confirm against DSP).
# `divisor` (5 s sample, 2.5 s hop intended) is accepted but currently ignored.
# TODO: turn `divisor` into an overlap fraction and keep `audio_loader`'s timestamps
# in sync with it.
#
# Returns `(raw_images, n_samples)`.
function get_images_from_audio(file::String, increment::Int=5, divisor::Int=2)
    signal, freq = load_audio_file(file)
    if freq > 16000
        signal, freq = _resample_to_16000hz(signal, freq)
    end
    fs = convert(Int, freq)
    clip_samples = increment * fs
    overlap_samples = 0 # no overlap between clips
    clips = DSP.arraysplit(signal[:, 1], clip_samples, overlap_samples)
    raw_images = ThreadsX.map(clip -> _get_image_from_sample(clip, fs), clips)
    return raw_images, length(raw_images)
end

# Load a WAV or FLAC file and return `(signal, freq)`.
#
# The extension check ignores letter case. This matches the case-insensitive WAV
# globbing in `predict_folder`, which also picks up e.g. `x.Wav`. WAV files are read
# with `WAV.wavread` (only its first two results are kept); other accepted files are
# read with `load`.
#
# Throws `ArgumentError` for unsupported extensions and an `ErrorException` when the
# signal has no samples.
function load_audio_file(file::String)
    ext = lowercase(split(file, ".")[end])
    ext in ("wav", "flac") ||
        throw(ArgumentError("Unsupported audio file type, requires wav or flac."))
    if ext == "wav"
        signal, freq = WAV.wavread(file)
    else
        signal, freq = load(file)
    end
    size(signal, 1) > 0 || error(
        "$file seems to be empty, could it be corrupted?\nYou could delete it, or replace it with a known\ngood version from SD card or backup.",
    )
    return signal, freq
end

############### AviaNZ .data file support ################

# Relabel the Kiwi segments of every AviaNZ `.data` file in `data_files`, one file
# at a time, using `labels` to map class indices to names.
function predict_avianz_folder(data_files::Vector{String}, model, folder::String, labels::Dict)
    @info "$(length(data_files)) .data files in $folder"
    foreach(data_file -> predict_avianz_file(data_file, model, labels), data_files)
end

# Classify the "Kiwi" segments of one AviaNZ `.data` file and rewrite the file.
#
# The audio file is the `.data` path with its ".data" suffix removed
# (e.g. `x.WAV.data` -> `x.WAV`). If it is missing, the file is skipped with a
# warning. Audio not at 8 kHz is resampled to 8 kHz.
#
# Element 1 of the parsed JSON is treated as metadata. Every later element is a
# segment `[start, end, low_freq, high_freq, labels]`. For each label dict whose
# "species" is "Kiwi", the segment's samples (channel 1) are rendered with
# `_get_image_from_sample` and classified. `labels` maps the `Flux.onecold` class index
# to a name, and the label's "species" is overwritten in place: "LSK" when the name is
# "LSK", otherwise "GSK". The modified JSON is written back to `data_file` (indent 2).
function predict_avianz_file(data_file::String, model, labels::Dict)
    @info "Processing: $data_file"
    data = JSON.parsefile(data_file)

    # Strip the ".data" suffix. `chop` counts characters, so it is safe for non-ASCII
    # names. The byte slice `data_file[1:end-5]` throws a StringIndexError when the
    # character just before ".data" is multi-byte.
    wav_file = String(chop(data_file; tail=length(".data")))
    if !isfile(wav_file)
        @warn "Audio file not found: $wav_file, skipping"
        return
    end

    signal, freq = load_audio_file(wav_file)
    if freq != 8000
        signal, freq = _resample_to_8000hz(signal, freq)
    end
    f = convert(Int, freq)

    # Collect images and references to the Kiwi label dicts. The dicts live inside
    # `data`, so mutating them below changes what is written back.
    raw_images = []
    kiwi_labels = []

    # Elements 2..N are segments (index 1 is metadata in AviaNZ format)
    for i in 2:length(data)
        segment = data[i]
        # Each segment is [start_time, end_time, low_freq, high_freq, [labels...]]
        segment_labels = segment[5]
        for label in segment_labels
            if isa(label, Dict) && haskey(label, "species") && label["species"] == "Kiwi"
                start_time = segment[1]
                end_time = segment[2]
                start_sample = max(1, round(Int, start_time * f) + 1)
                end_sample = min(size(signal, 1), round(Int, end_time * f))
                if end_sample <= start_sample
                    @warn "Empty segment at $start_time-$end_time in $data_file, skipping"
                    continue
                end
                sample = signal[start_sample:end_sample, 1]
                image = _get_image_from_sample(sample, f)
                push!(raw_images, image)
                push!(kiwi_labels, label)
            end
        end
    end

    if isempty(raw_images)
        @info "No Kiwi segments in $data_file"
        return
    end

    @info "$(length(raw_images)) Kiwi segments in $data_file"

    # Predict. Scores are moved to the CPU before decoding so `preds` holds plain Ints
    # whichever device the model runs on.
    n_samples = length(raw_images)
    loader = avianz_loader(raw_images, n_samples)
    preds = Int[]
    for x in loader
        append!(preds, Flux.onecold(cpu(model(x))))
    end

    # Update each Kiwi label in place.
    # Alternative mapping used with the DFMN call-type models:
    #   pred == "Don't Know" -> label["species"] = "Don't Know"; delete!(label, "calltype")
    #   otherwise            -> label["calltype"] = pred
    # Current mapping (LSK/GSK model): anything other than "LSK" becomes "GSK".
    for (label, class_idx) in zip(kiwi_labels, preds)
        pred = labels[class_idx]
        label["species"] = pred == "LSK" ? "LSK" : "GSK"
    end

    # Write modified JSON back to .data file
    open(data_file, "w") do io
        JSON.print(io, data, 2)
    end
    @info "Updated $data_file"
end

# Stack AviaNZ segment images into a single batch of `n_samples` images.
#
# Each H x W x 3 image goes through an RGB colorview and a Float32 channelview
# (3 x H x W), and is then permuted by (3, 2, 1) to W x H x 3. This is the same
# conversion the PNG getindex applies after loading a file. Images are concatenated
# along dimension 4. When `device == gpu`, the loader is wrapped in `CuIterator`.
function avianz_loader(raw_images::Vector, n_samples::Int)
    #! format: off
    to_png_layout(img) =
        colorview(RGB, permutedims(img, (3, 1, 2))) |>
        x -> Images.RGB.(x) |>
        x -> collect(channelview(float32.(x))) |>
        x -> permutedims(x, (3, 2, 1))
    #! format: on
    batch = cat([to_png_layout(img) for img in raw_images]...; dims=4)
    loader = Flux.DataLoader(batch; batchsize=n_samples, shuffle=false)
    if device == gpu
        loader = CuIterator(loader)
    end
    return loader
end

############### Temp test functions — delete after testing ################

# Debug helper: render the spectrogram image for [start_time, end_time] seconds of
# channel 1 of `audio_file`, after resampling to 8 kHz when needed. The samples are
# selected exactly as in the AviaNZ prediction path.
function get_one(audio_file::String, start_time::Float64, end_time::Float64)
    signal, freq = load_audio_file(audio_file)
    if freq != 8000
        signal, freq = _resample_to_8000hz(signal, freq)
    end
    fs = convert(Int, freq)
    first_sample = max(1, round(Int, start_time * fs) + 1)
    last_sample = min(size(signal, 1), round(Int, end_time * fs))
    return _get_image_from_sample(signal[first_sample:last_sample, 1], fs)
end

# Debug helper: convert an H x W x 3 array via an RGB colorview and a Float32
# channelview, then permute (3, 2, 1). The result is W x H x 3 Float32.
function round_trip(img)
    rgb = Images.RGB.(colorview(RGB, permutedims(img, (3, 1, 2))))
    chw = collect(channelview(float32.(rgb)))
    return permutedims(chw, (3, 2, 1))
end

# Debug helper: compare a raw spectrogram image with its `round_trip` conversion.
# Usage: SkraakML.compare("tx51_LISTENING_20260221_203004.WAV", 5.0, 35.0)
# Logs sizes, element types, equality, value ranges, counts of values outside
# [0, 1], and the max absolute difference for every raw/round-trip channel pair.
function compare(audio_file::String, start_time::Float64, end_time::Float64)
    img = get_one(audio_file, start_time, end_time)
    rt = round_trip(img)
    maxdiff(a, b) = maximum(abs.(a .- b))
    @info "Sizes: raw=$(size(img)) roundtrip=$(size(rt))"
    @info "Types: raw=$(eltype(img)) roundtrip=$(eltype(rt))"
    @info "Equal: $(img == rt)"
    @info "Max difference: $(maxdiff(img, rt))"
    @info "Raw range: min=$(minimum(img)) max=$(maximum(img))"
    @info "Roundtrip range: min=$(minimum(rt)) max=$(maximum(rt))"
    @info "Values > 1.0 in raw: $(count(>(1.0), img))"
    @info "Values < 0.0 in raw: $(count(<(0.0), img))"
    # Does raw channel i match round-trip channel i, or a different one?
    for i in 1:3, j in 1:3
        d = maxdiff(img[:, :, i], rt[:, :, j])
        @info "raw channel $i vs roundtrip channel $j: max diff = $d"
    end
end

############### PYTHON Opensoundscape ################
#=
# Python 3.8.12, opensoundscape 0.7.1
# Don't forget: conda activate opensoundscape
# Don't forget to modify the file names and the glob pattern
# Run script in Pomona-2, hard code trip date in the glob
# python /media/david/USB/Skraak/src/predict.py

from opensoundscape.torch.models.cnn import load_model
import opensoundscape

import torch
from pathlib import Path
import numpy as np
import pandas as pd

from glob import glob
import os
from datetime import date
todays_date = date.today().strftime('%Y-%m-%d')
from datetime import datetime

model = load_model('/home/david/best.model0')

# folders = glob('./*/2023-?????/')
# folders = glob('./*/*/2024-05-0?')
# folders = glob('./*/2024-10-18/')
folders =  glob('./*/2025-02-25/')
for folder in folders:
    os.chdir(folder)
    print(folder, ' start: ', datetime.now())
    # Beware, secretary island files are .wav
    field_recordings = glob('./*.[W,w][A,a][V,v]')
    scores, preds, unsafe = model.predict(
            field_recordings,
            binary_preds = 'single_target',
            overlap_fraction = 0.5,
            batch_size =  128,
            num_workers = 12)
    scores.to_csv('scores_opensoundscape-kiwi-1.0_{todays_date}.csv')
    preds.to_csv('preds_opensoundscape-kiwi-1.0_{todays_date}.csv')
    os.chdir('../..') # Be careful this matches the glob on line 284
    print(folder, ' done: ', datetime.now())
    print()
    print()
=#
#=Kahurangi
folders =  Glob.glob('./*/')
for folder in folders:
    os.chdir(folder)
    print(folder, ' start: ', datetime.now())
    # Beware, secretary island files are .wav
    field_recordings = Glob.glob('./*.[W,w][A,a][V,v]')
    scores, preds, unsafe = model.predict(
            field_recordings,
            binary_preds = 'single_target',
            overlap_fraction = 0.5,
            batch_size =  128,
            num_workers = 12)
    scores.to_csv("scores-2024-10-21.csv")
    preds.to_csv("preds-2024-10-21.csv")
    os.chdir('./..') # Be careful this matches the glob on line 284
    print(folder, ' done: ', datetime.now())
    print()
    print()
=#