# Predict.jl
export predict
export get_images_from_audio
using WAV,
DSP, Images, ThreadsX, Dates, DataFrames, CSV, Flux, CUDA, Metalhead, JLD2, JSON, FLAC, Glob, PerceptualColourMaps
import Base: length, getindex
##Dependency, duplicated from Utility
"""
    _resample_to_16000hz(signal, freq)

Resample `signal` along its first dimension from `freq` Hz to 16 kHz.
Returns the resampled signal together with the new sample rate (`16000`).
"""
function _resample_to_16000hz(signal, freq)
    ratio = 16000.0f0 / freq
    resampled = DSP.resample(signal, ratio; dims=1)
    return resampled, 16000
end
"""
    _resample_to_8000hz(signal, freq)

Resample `signal` along its first dimension from `freq` Hz to 8 kHz.
Returns the resampled signal together with the new sample rate (`8000`).
"""
function _resample_to_8000hz(signal, freq)
    ratio = 8000.0f0 / freq
    resampled = DSP.resample(signal, ratio; dims=1)
    return resampled, 8000
end
##Dependency, duplicated from Clips
"""
    _get_image_from_sample(sample, f)

Render an audio `sample` (sampled at `f` Hz) as a colour spectrogram image for the model.

Steps:
1. `DSP.spectrogram` with a window of 400 samples and an overlap of 2 samples.
2. If any power value is exactly zero, replace the zeros with the smallest non-zero power,
   so that `pow2db` stays finite.
3. Convert to dB, shift by `abs(minimum)`, then divide by the maximum.
4. Flip the first axis. DSP returns frequency bins along the rows in increasing order, so
   low frequencies end up at the bottom.
5. Apply the PerceptualColourMaps "L4" colour map, resize to 224×224 and convert to `Float32`.

The result is presumably a rows × cols × 3 `Float32` array (TODO confirm).

This is duplicated from the Clips code. Presumably it must match the preprocessing used
to build the training images, so any change here changes the model inputs.

NOTE(review): if every power value is zero, `l[2]` throws a `BoundsError`.
NOTE(review): the `abs(minimum(x))` shift only normalises to [0, 1] when the minimum dB
value is <= 0.
"""
function _get_image_from_sample(sample, f) #sample::Vector{Float64}
S = DSP.spectrogram(sample, 400, 2; fs=convert(Int, f))
i = S.power
# Zero power would become -Inf dB; substitute the smallest non-zero power instead.
if minimum(i) == 0.0
l = i |> vec |> unique |> sort
replace!(i, 0.0 => l[2])
end
# dB -> non-negative shift -> scale by max -> flip frequency axis -> colour map -> 224x224 -> Float32
image =
#! format: off
DSP.pow2db.(i) |>
x -> x .+ abs(minimum(x)) |>
x -> x ./ maximum(x) |>
x -> reverse(x, dims = 1) |>
x -> PerceptualColourMaps.applycolourmap(x, cmap("L4")) |>
#x -> RGB.(x) |>
x -> imresize(x, 224, 224) |>
x -> Float32.(x)
#! format: on
return image
end
"""
    predict(glob_pattern::String, model::String)
    predict(folders::Vector{String}, model::String)
    predict(folders::Vector{String}, model::String, labels::Dict)

Run a saved model over a set of folders and save the results in each folder, similar to
opensoundscape.

# Arguments
- `glob_pattern`: a glob pattern matching the folders to process (e.g. `"folder/"`); or
  `folders`: a vector of folder paths.
- `model`: path to a `.jld2` file holding the `model_state` to load.
- `labels` (vector method only): maps a predicted class index to a label string. When it
  is non-empty, folders that contain AviaNZ `.data` files have those files updated in
  place instead.

# Returns
`nothing`. Results are written to `preds-<today's date>.csv` in each folder.

Works on both audio (wav or flac) and png images. When a folder contains both, only the
png images are predicted. I use this to find kiwi in new data gathered on a trip, and to
predict D/F/M/N for images clipped from primary detections.

# Notes
Run from `Pomona-3/Pomona-3/` with `julia +1.10 -t 4`.
Don't forget to use a temporary environment: `] activate --temp`.

# Examples
```julia
using SkraakML, Glob

glob_pattern = "*/*/"
model = "/media/david/SSD2/PrimaryDataset/model_K1-9_original_set_CPU_epoch-7-0.9924-2024-03-05.jld2"
predict(glob_pattern, model)

# Other glob patterns used:
glob_pattern = "Clips_2024-10-21/"
glob_pattern = "Clips_Kahurangi_2024-10-25"
glob_pattern = "Clips_Cobb_2024-10-25/"
glob_pattern = "Clips_2024-10-28_MT6"
glob_pattern = "Pomona-*/Pomona/Clips_2024-11-27/"

# Other models used:
model = "/media/david/SSD1/Clips/model_DFMN1-5_CPU_epoch-18-0.9132-2024-01-29.jld2"
model = "/media/david/SSD2/Secondary_Models/DFMN_Inge/model_DFMN1-5_CPU_epoch-9-0.9737-2024-10-25.jld2"
model = "/media/david/SSD2/DFMN_Pomona/model_DFMN1-5_Pomona2_CPU_epoch-2-0.9459-2025-01-19.jld2"
model = "/media/david/SSD2/Secondary_Models/LSK/model_GSK_LSK_DFM_FT_IngeDFMN_1-5_1-0_CPU_epoch-9-0.9745-2025-01-13.jld2"
model = "/media/david/SSD2/Secondary_Models/DFMN_Pomona/model_DFMN1-5_Pomona3_CPU_epoch-18-0.9785-2025-03-02.jld2"
predict(glob_pattern, model)

# With an explicit vector of folders:
folders = glob("NEW*/*/Clips_*/Kiwi/")
predict(folders, model)
```
"""
function predict(glob_pattern::String, model::String)
    # Expand the pattern and reuse the vector method, so model loading, logging and the
    # per-folder loop live in one place.
    folders = Glob.glob(glob_pattern)
    return predict(folders, model)
end
# Load the model once, move it to `device`, then predict each folder in turn.
function predict(folders::Vector{String}, model::String)
    net = device(load_model_pred(model))
    @info "Folders: $folders"
    foreach(folders) do folder
        @info "Working on: $folder"
        predict_folder(folder, net)
    end
end
# Same as the two-argument vector method, but forwards `labels` (class index => label
# string) so folders containing AviaNZ `.data` files can be updated in place.
function predict(folders::Vector{String}, model::String, labels::Dict)
    net = device(load_model_pred(model))
    @info "Folders: $folders"
    foreach(folders) do folder
        @info "Working on: $folder"
        predict_folder(folder, net, labels)
    end
end
#~~~~~ The guts ~~~~~#
# Rebuild the inference network and load saved weights into it. The training code has a
# similar load_model() that takes different input types.
# Architecture: the first part of an untrained Metalhead ResNet-18, followed by a head of
# adaptive mean pool -> flatten -> Dense(512 => n_classes).
function load_model_pred(model_path::String)
    state = JLD2.load(model_path, "model_state")
    # NOTE(review): presumably this nested entry is the head's Dense bias, whose length
    # equals the number of classes. Confirm against the saved state layout.
    n_classes = length(state[1][2][1][3][2])
    @info "Model classes: $n_classes"
    resnet_layers = Metalhead.ResNet(18, pretrain=false).layers
    backbone = resnet_layers[1]
    head = Flux.Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => n_classes))
    net = Flux.Chain(backbone, head)
    Flux.loadmodel!(net, state)
    return net
end
#=
function load_bson(model_path::String)
BSON.@load model_path model
end
=#
"""
    predict_folder(folder::String, model, labels::Dict=Dict())

Predict on the contents of one `folder`. The first matching case wins:

1. AviaNZ `.data` files are present and `labels` is non-empty: update the `.data` files.
2. Any `.png` files are present: predict on the images (audio in the folder is ignored).
3. `.wav` (any letter case) or `.flac` files are present: predict on the audio.

Otherwise an info message is logged. Returns `nothing`.
"""
function predict_folder(folder::String, model, labels::Dict=Dict())
    data_files = Glob.glob("$folder/*.data")
    if !isempty(data_files) && !isempty(labels)
        predict_avianz_folder(data_files, model, folder, labels)
        return nothing
    end
    png_files = Glob.glob("$folder/*.png")
    if !isempty(png_files)
        # Images take precedence when both images and audio are present.
        predict_image_folder(png_files, model, folder)
        return nothing
    end
    # One character class per letter, matching any case of "wav".
    # (The previous "[W,w]" form also matched a literal ','.)
    wav = Glob.glob("$folder/*.[Ww][Aa][Vv]")
    flac = Glob.glob("$folder/*.flac")
    audio_files = vcat(wav, flac) # if both wav and flac are present, predict on all
    if isempty(audio_files)
        @info "No png, flac, wav, WAV files present in $folder"
    else
        predict_audio_folder(audio_files, model, folder)
    end
    return nothing
end
# Flux device function (`gpu` if CUDA is functional, otherwise `cpu`). Models are moved
# with `|> device`, and the loaders compare `device == gpu` to decide whether to wrap
# themselves in `CuIterator`.
# NOTE(review): this is a non-const global evaluated once, when this file is loaded
# (presumably at precompile time if it is part of a package). Confirm that the GPU check
# should not happen at runtime instead (e.g. in `__init__`).
device = CUDA.functional() ? gpu : cpu
# Predict from png images
"""
    PredictImageContainer(img)

Lazy dataset over image file paths, used with `Flux.DataLoader`. Indexing loads the image
from disk, preprocesses it, and returns `(array, path)`.
"""
struct PredictImageContainer{T<:Vector}
    img::T
end

length(data::PredictImageContainer) = length(data.img)

function getindex(data::PredictImageContainer{Vector{String}}, idx::Int)
    path = data.img[idx]
    resized = Images.imresize(Images.load(path), 224, 224)
    as_rgb = Images.RGB.(resized)
    # channelview puts the colour channel first; permuting (3, 2, 1) moves it last and
    # swaps the two spatial axes.
    channels_first = collect(channelview(float32.(as_rgb)))
    img = permutedims(channels_first, (3, 2, 1))
    return img, path
end
"""
    predict_image_folder(png_files::Vector{String}, model, folder::String)

Classify every png in `png_files` and write `preds-<today>.csv` into `folder`, with
columns `file` (the part of each path after its last `/`) and `label` (class index).
"""
function predict_image_folder(png_files::Vector{String}, model, folder::String)
    n_png = length(png_files)
    @assert (n_png > 0) "No png files present in $folder"
    @info "$(n_png) png_files in $folder"
    save_path = "$folder/preds-$(today()).csv"
    loader = png_loader(png_files)
    @time preds, files = predict_pngs(model, loader)
    file_names = last.(split.(files, "/"))
    results = DataFrames.DataFrame(file=file_names, label=preds)
    return CSV.write("$save_path", results)
end
# Batched, parallel loader over png paths (batches of 64). On GPU it is wrapped in
# CuIterator so batches are moved to the device as they are iterated.
function png_loader(png_files::Vector{String})
    dataset = PredictImageContainer(png_files)
    loader = Flux.DataLoader(dataset; batchsize=64, collate=true, parallel=true)
    if device == gpu
        loader = CuIterator(loader)
    end
    return loader
end
# Run model `m` over each `(images, paths)` batch from loader `d`.
# Returns the `onecold` class indices and the matching paths, in loader order.
function predict_pngs(m, d)
    @info "Predicting..."
    predicted = []
    seen_paths = []
    for batch in d
        images, batch_paths = batch
        append!(predicted, Flux.onecold(m(images)))
        append!(seen_paths, batch_paths)
    end
    return predicted, seen_paths
end
# Predict from audio files
# Write `preds-<today>.csv` into `folder`. The header (file, start_time, end_time, label)
# is written first, then each audio file's predictions are appended as they finish.
function predict_audio_folder(audio_files::Vector{String}, model, folder::String)
    n_files = length(audio_files)
    @assert (n_files > 0) "No wav or flac audio files present in $folder"
    @info "$(n_files) audio_files in $folder"
    save_path = "$folder/preds-$(today()).csv"
    header = DataFrames.DataFrame(
        file=String[],
        start_time=Float64[],
        end_time=Float64[],
        label=Int[],
    )
    CSV.write("$save_path", header)
    foreach(audio_files) do file
        CSV.write("$save_path", predict_audio_file(file, model), append=true)
    end
end
"""
    predict_audio_file(file::String, model)

Predict a class for every clip of one audio file.

Returns a `DataFrame` with columns `file`, `start_time`, `end_time` (seconds from the start
of the file) and `label` (the `onecold` class index), sorted by its columns.
"""
function predict_audio_file(file::String, model)
    # TODO: check the form of opensoundscape's preds.csv and what make_clips needs.
    @info "File: $file"
    @time data = audio_loader(file)
    # Typed accumulators give concrete DataFrame column types instead of `Any`.
    labels = Int[]
    clip_times = Tuple{Float64,Float64}[]
    @time for (x, t) in data
        append!(labels, Flux.onecold(model(x)))
        append!(clip_times, t)
    end
    df = DataFrames.DataFrame(
        :file => fill(file, length(clip_times)),
        :start_time => first.(clip_times),
        :end_time => last.(clip_times),
        :label => labels,
    )
    sort!(df)
    return df
end
"""
    audio_loader(file::String, increment::Int=5, divisor::Int=2)

Build a `Flux.DataLoader` over one audio file. It yields a single batch
`(images, times)`, where `images` holds one spectrogram image per `increment`-second clip
(stacked along dimension 4) and `times` holds a `(start, end)` tuple in seconds, relative
to the start of the file, for each clip. `divisor` is passed through to
`get_images_from_audio`.
"""
function audio_loader(file::String, increment::Int=5, divisor::Int=2)
    raw_images, n_samples = get_images_from_audio(file, increment, divisor)
    images = reshape_images(raw_images, n_samples)
    # get_images_from_audio splits the signal into windows with zero overlap, so clip k
    # starts at (k - 1) * increment seconds. Stepping by increment / divisor (50% overlap)
    # mislabelled every clip after the first. Keep this step in sync with the overlap used
    # in get_images_from_audio.
    start_time = float(increment) .* (0:n_samples-1)
    end_time = start_time .+ increment
    time = collect(zip(start_time, end_time))
    # DataLoader needs a positive batch size even when the file yields no clips.
    loader = Flux.DataLoader((images, time), batchsize=max(n_samples, 1), shuffle=false)
    if device == gpu
        loader = CuIterator(loader) # check this works with gpu
    end
    return loader
end
"""
    reshape_images(raw_images, n_samples)

Stack `n_samples` equally sized 3-D images (e.g. 224×224×3) into one 4-D batch, with
image `k` at `images[:, :, :, k]`.

`hcat` + `reshape` cannot be used here: `hcat` concatenates 3-D arrays along dimension 2,
so reshaping that result to (224, 224, 3, n) interleaves channels from different images
whenever `n > 1`.

Throws `DimensionMismatch` if `length(raw_images) != n_samples`. An empty input returns
an empty `224×224×3×0` `Float32` array.
"""
function reshape_images(raw_images, n_samples)
    length(raw_images) == n_samples || throw(DimensionMismatch(
        "expected $n_samples images, got $(length(raw_images))"))
    n_samples == 0 && return Array{Float32}(undef, 224, 224, 3, 0)
    first_image = first(raw_images)
    images = similar(first_image, (size(first_image)..., n_samples))
    for (k, img) in enumerate(raw_images)
        images[:, :, :, k] = img
    end
    return images
end
#= not needed
function get_image_for_inference(sample, f)
image =
#! format: off
_get_image_from_sample(sample, f) |>
# x -> collect(channelview(float32.(x))) |>
x -> permutedims(x, (3, 2, 1))
#! format: on
return image
end
=#
"""
    get_images_from_audio(file::String, increment::Int=5, divisor::Int=2)

Load `file` and downsample it to 16 kHz if it is sampled above that. Split the first
channel into consecutive `increment`-second windows and render each window as a
spectrogram image (computed in parallel with ThreadsX).

Returns `(raw_images, n_samples)`, where `n_samples == length(raw_images)`.

Windows currently do not overlap (the overlap is fixed at 0), so `divisor` has no effect
here. TODO: turn `divisor` into an overlap fraction, and keep `audio_loader`'s clip
timestamps consistent with whatever overlap is used.
"""
function get_images_from_audio(file::String, increment::Int=5, divisor::Int=2)
    signal, freq = load_audio_file(file)
    if freq > 16000
        signal, freq = _resample_to_16000hz(signal, freq)
    end
    sample_rate = convert(Int, freq)
    window_length = increment * sample_rate
    # Previously an overlap of f * increment / divisor (0 when divisor == 0).
    noverlap = 0
    windows = DSP.arraysplit(signal[:, 1], window_length, noverlap)
    raw_images = ThreadsX.map(w -> _get_image_from_sample(w, sample_rate), windows)
    return raw_images, length(raw_images)
end
"""
    load_audio_file(file::String)

Load an audio file and return `(signal, freq)`. The first channel is `signal[:, 1]`.

The extension is matched case-insensitively: `.wav` is read with `WAV.wavread` and
`.flac` with `load`. This matches `predict_folder`, whose wav glob accepts any letter case.

# Throws
- `ArgumentError` for any other extension (or no extension).
- An error if the first channel is empty (a possibly corrupted file).
"""
function load_audio_file(file::String)
    # splitext handles paths without an extension and dots in directory names.
    ext = lowercase(last(splitext(file)))
    ext in (".wav", ".flac") ||
        throw(ArgumentError("Unsupported audio file type, requires wav or flac."))
    if ext == ".wav"
        signal, freq = WAV.wavread(file)
    else
        signal, freq = load(file)
    end
    isempty(signal[:, 1]) && error(
        "$file seems to be empty, could it be corrupted?\nYou could delete it, or replace it with a known\ngood version from SD card or backup.")
    return signal, freq
end
############### AviaNZ .data file support ################
# Apply predict_avianz_file to every AviaNZ .data file found in `folder`.
function predict_avianz_folder(data_files::Vector{String}, model, folder::String, labels::Dict)
    @info "$(length(data_files)) .data files in $folder"
    foreach(data_file -> predict_avianz_file(data_file, model, labels), data_files)
end
"""
    predict_avianz_file(data_file::String, model, labels::Dict)

Classify every "Kiwi" label in one AviaNZ `.data` file and rewrite the file in place.

The audio is expected next to `data_file`, named as `data_file` without its `.data`
suffix. If that file does not exist, a warning is logged and nothing is changed. The audio
is resampled to 8 kHz when needed, and each Kiwi segment is rendered with
`_get_image_from_sample`. The model's class index is then looked up in `labels`.

Only the LSK/GSK mapping is currently active: a prediction of "LSK" sets `species` to
"LSK", and anything else sets it to "GSK". Returns `nothing`.
"""
function predict_avianz_file(data_file::String, model, labels::Dict)
    @info "Processing: $data_file"
    data = JSON.parsefile(data_file)
    # Strip the ".data" suffix. splitext is character-aware, whereas byte slicing
    # (data_file[1:end-5]) throws a StringIndexError when the character just before
    # ".data" is multi-byte.
    wav_file = first(splitext(data_file))
    if !isfile(wav_file)
        @warn "Audio file not found: $wav_file, skipping"
        return nothing
    end
    signal, freq = load_audio_file(wav_file)
    if freq != 8000
        signal, freq = _resample_to_8000hz(signal, freq)
    end
    f = convert(Int, freq)
    raw_images, kiwi_labels = _avianz_kiwi_images(data, signal, f, data_file)
    if isempty(raw_images)
        @info "No Kiwi segments in $data_file"
        return nothing
    end
    n_samples = length(raw_images)
    @info "$(n_samples) Kiwi segments in $data_file"
    preds = Int[]
    for x in avianz_loader(raw_images, n_samples)
        append!(preds, Flux.onecold(model(x)))
    end
    # kiwi_labels and preds are the same length: one prediction per collected image.
    for (label, class_index) in zip(kiwi_labels, preds)
        pred = labels[class_index]
        #= DFMN calltype mode (currently disabled):
        if pred == "Don't Know"
            label["species"] = "Don't Know"
            delete!(label, "calltype")
        else
            label["calltype"] = pred
        end
        =#
        # For LSK/GSK
        label["species"] = pred == "LSK" ? "LSK" : "GSK"
    end
    # Write the modified JSON back to the .data file.
    open(data_file, "w") do io
        JSON.print(io, data, 2)
    end
    @info "Updated $data_file"
    return nothing
end

# Render one spectrogram image per "Kiwi" label in the parsed AviaNZ document `data`.
# Returns the images together with the label dicts they came from; these are the same
# objects as in `data`, so the caller's edits are written back with the document.
function _avianz_kiwi_images(data, signal, f::Int, data_file::String)
    raw_images = Any[]
    kiwi_labels = Dict[]
    # Element 1 is metadata in the AviaNZ format; segments are elements 2..N.
    for i in 2:length(data)
        segment = data[i]
        # Each segment is [start_time, end_time, low_freq, high_freq, [labels...]]
        for label in segment[5]
            is_kiwi = isa(label, Dict) && haskey(label, "species") && label["species"] == "Kiwi"
            is_kiwi || continue
            start_time, end_time = segment[1], segment[2]
            start_sample = max(1, round(Int, start_time * f) + 1)
            end_sample = min(size(signal, 1), round(Int, end_time * f))
            if end_sample <= start_sample
                @warn "Empty segment at $start_time-$end_time in $data_file, skipping"
                continue
            end
            push!(raw_images, _get_image_from_sample(signal[start_sample:end_sample, 1], f))
            push!(kiwi_labels, label)
        end
    end
    return raw_images, kiwi_labels
end
# Build a single-batch DataLoader from the segment images. Each image (presumably
# rows × cols × 3) is packed into an RGB image and then unpacked with the same
# RGB -> channelview -> permutedims(3, 2, 1) conversion used for png inputs in getindex.
function avianz_loader(raw_images::Vector, n_samples::Int)
    function to_png_layout(img)
        rgb = Images.RGB.(colorview(RGB, permutedims(img, (3, 1, 2))))
        channels_first = collect(channelview(float32.(rgb)))
        return permutedims(channels_first, (3, 2, 1))
    end
    converted = [to_png_layout(img) for img in raw_images]
    batch = cat(converted...; dims=4)
    loader = Flux.DataLoader(batch, batchsize=n_samples, shuffle=false)
    if device == gpu
        loader = CuIterator(loader)
    end
    return loader
end
############### Temp test functions — delete after testing ################
# Debug helper: render the spectrogram image for [start_time, end_time] seconds of
# `audio_file` at 8 kHz, using the same slicing as predict_avianz_file.
function get_one(audio_file::String, start_time::Float64, end_time::Float64)
    signal, freq = load_audio_file(audio_file)
    if freq != 8000
        signal, freq = _resample_to_8000hz(signal, freq)
    end
    rate = convert(Int, freq)
    first_sample = max(1, round(Int, start_time * rate) + 1)
    last_sample = min(size(signal, 1), round(Int, end_time * rate))
    clip = signal[first_sample:last_sample, 1]
    return _get_image_from_sample(clip, rate)
end
# Debug helper: apply avianz_loader's per-image conversion to one image.
function round_trip(img)
    as_rgb = Images.RGB.(colorview(RGB, permutedims(img, (3, 1, 2))))
    channels_first = collect(channelview(float32.(as_rgb)))
    return permutedims(channels_first, (3, 2, 1))
end
# Debug helper: log how a raw segment image differs from its round_trip conversion.
# Example: SkraakML.compare("tx51_LISTENING_20260221_203004.WAV", 5.0, 35.0)
function compare(audio_file::String, start_time::Float64, end_time::Float64)
    img = get_one(audio_file, start_time, end_time)
    rt = round_trip(img)
    @info "Sizes: raw=$(size(img)) roundtrip=$(size(rt))"
    @info "Types: raw=$(eltype(img)) roundtrip=$(eltype(rt))"
    @info "Equal: $(img == rt)"
    @info "Max difference: $(maximum(abs.(img .- rt)))"
    @info "Raw range: min=$(minimum(img)) max=$(maximum(img))"
    @info "Roundtrip range: min=$(minimum(rt)) max=$(maximum(rt))"
    @info "Values > 1.0 in raw: $(sum(img .> 1.0))"
    @info "Values < 0.0 in raw: $(sum(img .< 0.0))"
    # Compare every raw channel with every round-trip channel, to spot channel reordering.
    for i in 1:3, j in 1:3
        d = maximum(abs.(img[:, :, i] .- rt[:, :, j]))
        @info "raw channel $i vs roundtrip channel $j: max diff = $d"
    end
end
############### PYTHON Opensoundscape ################
#=
# Python 3.8.12, opensoundscape 0.7.1
# Don't forget: conda activate opensoundscape
# Don't forget to modify file names and glob pattern
# Run script in Pomona-2, hard code trip date in the glob
# python /media/david/USB/Skraak/src/predict.py
from opensoundscape.torch.models.cnn import load_model
import opensoundscape
import torch
from pathlib import Path
import numpy as np
import pandas as pd
from glob import glob
import os
from datetime import date
todays_date = date.today().strftime('%Y-%m-%d')
from datetime import datetime
model = load_model('/home/david/best.model0')
# folders = glob('./*/2023-?????/')
# folders = glob('./*/*/2024-05-0?')
# folders = glob('./*/2024-10-18/')
folders = glob('./*/2025-02-25/')
for folder in folders:
os.chdir(folder)
print(folder, ' start: ', datetime.now())
# Beware, secretary island files are .wav
field_recordings = glob('./*.[W,w][A,a][V,v]')
scores, preds, unsafe = model.predict(
field_recordings,
binary_preds = 'single_target',
overlap_fraction = 0.5,
batch_size = 128,
num_workers = 12)
scores.to_csv(f'scores_opensoundscape-kiwi-1.0_{todays_date}.csv')
preds.to_csv(f'preds_opensoundscape-kiwi-1.0_{todays_date}.csv')
os.chdir('../..') # Be careful: this must climb back out of as many levels as the folder glob above descends
print(folder, ' done: ', datetime.now())
print()
print()
=#
#=Kahurangi
folders = glob('./*/')
for folder in folders:
os.chdir(folder)
print(folder, ' start: ', datetime.now())
# Beware, secretary island files are .wav
field_recordings = glob('./*.[W,w][A,a][V,v]')
scores, preds, unsafe = model.predict(
field_recordings,
binary_preds = 'single_target',
overlap_fraction = 0.5,
batch_size = 128,
num_workers = 12)
scores.to_csv("scores-2024-10-21.csv")
preds.to_csv("preds-2024-10-21.csv")
os.chdir('./..') # Be careful: this must climb back out of as many levels as the folder glob above descends
print(folder, ' done: ', datetime.now())
print()
print()
=#