"""
    predict_avianz_folder(data_files, model, folder, labels)

Log how many `.data` files were supplied for `folder`, then run
`predict_avianz_file` on each of them, in order.

# Arguments
- `data_files::Vector{String}`: paths of the `.data` files to process.
- `model`: forwarded unchanged to `predict_avianz_file`.
- `folder::String`: only used in the log message.
- `labels::Dict`: forwarded unchanged to `predict_avianz_file`.

# Returns
`nothing`.
"""
function predict_avianz_folder(
    data_files::Vector{String}, model, folder::String, labels::Dict
)
    @info "$(length(data_files)) .data files in $folder"
    foreach(path -> predict_avianz_file(path, model, labels), data_files)
    return nothing
end
+
"""
    predict_avianz_file(data_file, model, labels)

Re-label the `"Kiwi"` segments of one AviaNZ `.data` annotation file using `model`,
and write the updated annotations back to `data_file`.

The audio file is expected at `data_file` with the `.data` suffix removed. Every
label dict whose `"species"` is `"Kiwi"` gets the corresponding audio slice turned
into an image (`_get_image_from_sample`), classified in one pass through
`avianz_loader`, and then updated in place:

- prediction `"Don't Know"`: `"species"` becomes `"Don't Know"` and `"calltype"`
  is removed;
- any other prediction: stored under `"calltype"`.

# Arguments
- `data_file::String`: path to the annotation file; must end in `.data`.
- `model`: callable applied to each batch yielded by `avianz_loader`; its output
  is decoded with `Flux.onecold`.
- `labels::Dict`: maps the class index returned by `Flux.onecold` to a name.

# Returns
`nothing`. The file is only rewritten when at least one Kiwi segment was found;
files without the `.data` suffix or without their audio file are skipped with a
warning.

# Throws
Whatever `JSON.parsefile`, the audio helpers or `model` throw, and `KeyError` if a
predicted index is missing from `labels`.
"""
function predict_avianz_file(data_file::String, model, labels::Dict)
    @info "Processing: $data_file"

    # Derive the audio path by removing the suffix character-wise. Blind byte
    # slicing (`[1:end-5]`) yields a wrong path for other names and throws
    # StringIndexError when the character before the suffix is multi-byte.
    suffix = ".data"
    if !endswith(data_file, suffix)
        @warn "Not a .data file: $data_file, skipping"
        return nothing
    end
    wav_file = String(chop(data_file; tail=length(suffix)))
    if !isfile(wav_file)
        @warn "Audio file not found: $wav_file, skipping"
        return nothing
    end

    # Parse only once we know the file will actually be processed.
    data = JSON.parsefile(data_file)

    signal, freq = load_audio_file(wav_file)
    if freq != 8000
        signal, freq = _resample_to_8000hz(signal, freq)
    end
    f = convert(Int, freq)
    n_total = size(signal, 1)

    # Images to classify and, at the same position, the label dict to update.
    raw_images = []  # element type is whatever _get_image_from_sample returns
    kiwi_labels = AbstractDict[]

    # Elements 2..N are segments (index 1 is metadata in AviaNZ format)
    for segment in Iterators.drop(data, 1)
        # Each segment is [start_time, end_time, low_freq, high_freq, [labels...]]
        start_time, end_time = segment[1], segment[2]
        for label in segment[5]
            # AbstractDict rather than Dict, so labels are still recognised when
            # the JSON parser returns another dictionary type.
            is_kiwi = label isa AbstractDict && get(label, "species", nothing) == "Kiwi"
            is_kiwi || continue

            # Convert times (seconds) to 1-based sample indices, clamped to the signal.
            start_sample = max(1, round(Int, start_time * f) + 1)
            end_sample = min(n_total, round(Int, end_time * f))
            if end_sample <= start_sample
                @warn "Empty segment at $start_time-$end_time in $data_file, skipping"
                continue
            end

            sample = signal[start_sample:end_sample, 1]
            push!(raw_images, _get_image_from_sample(sample, f))
            push!(kiwi_labels, label)
        end
    end

    if isempty(raw_images)
        @info "No Kiwi segments in $data_file"
        return nothing
    end

    @info "$(length(raw_images)) Kiwi segments in $data_file"

    # Classify all segments; the loader is asked for a single batch holding them all.
    loader = avianz_loader(raw_images, length(raw_images))
    preds = Int[]
    for x in loader
        # Bring the output to the host before decoding, so appending to a CPU
        # vector does not depend on element-wise reads from a device array.
        # NOTE(review): `Flux.cpu` should be a no-op for CPU arrays — confirm.
        append!(preds, Flux.onecold(Flux.cpu(model(x))))
    end

    # Update species/calltype in each Kiwi segment's label. The dicts are the same
    # objects held inside `data`, so the edits show up in the JSON written below.
    for (label, class_index) in zip(kiwi_labels, preds)
        pred = labels[class_index]
        if pred == "Don't Know"
            label["species"] = "Don't Know"
            delete!(label, "calltype")
        else
            label["calltype"] = pred
        end
    end

    # Write modified JSON back to .data file
    open(data_file, "w") do io
        JSON.print(io, data, 2)
    end
    @info "Updated $data_file"
    return nothing
end
+
"""
    avianz_loader(raw_images, n_samples)

Convert `raw_images` into one 4-D array and wrap it in a non-shuffling
`Flux.DataLoader` with batch size `n_samples`.

Each image has its dimensions permuted by `(3, 1, 2)`, is viewed as `RGB`,
converted to `Float32` channels, and permuted by `(3, 2, 1)`; the results are
concatenated along dimension 4.
NOTE(review): this presumably expects height × width × channel input and yields
width × height × channel × batch — confirm against `_get_image_from_sample`.

When the global `device` equals `gpu`, the loader is wrapped in a `CuIterator`.

# Arguments
- `raw_images::Vector`: 3-D image arrays, all of the same size.
- `n_samples::Int`: batch size handed to `Flux.DataLoader`.

# Returns
The `Flux.DataLoader`, or a `CuIterator` around it.
"""
function avianz_loader(raw_images::Vector, n_samples::Int)
    # Per-image conversion, written as named steps instead of a pipe chain.
    function to_model_layout(img)
        rgb = Images.RGB.(colorview(RGB, permutedims(img, (3, 1, 2))))
        channels = collect(channelview(float32.(rgb)))
        return permutedims(channels, (3, 2, 1))
    end

    processed = map(to_model_layout, raw_images)
    images = cat(processed..., dims=4)
    loader = Flux.DataLoader(images, batchsize=n_samples, shuffle=false)
    if device == gpu
        return CuIterator(loader)
    end
    return loader
end
+