# Train.jl

export train # beware: Flux.train! is not Skraak.train

import Base: length, getindex
import MLBase
using CUDA, Dates, Images, Flux, Glob, JLD2, Noise
using Random: shuffle!, seed!
using Metalhead: ResNet

#=
function train(
    model_name::String,
    train_epochs::Int64,
    images::Vector{String},
    pretrain::Model=true,
    train_test_split::Float64 = 0.8,
    batch_size::Int64 = 64,
)

Note: Dont forget temp env, julia -t 4
Assumes 224x224 pixel RGB images as png's
Saves jld2's in current directory

Use like:
using Skraak, Glob
images = Glob.glob("kiwi_set*/*/[N,K]/*.png") #11699814-element Vector{String}
model = "/media/david/SSD2/PrimaryDataset/model_K1-9_original_set_CPU_epoch-7-0.9924-2024-03-05.jld2"
train("K1-10_total_set_no_augumentation", 2, images, model, 0.97, 64)

images = Glob.glob("*/[D,F,M,N]/*.png") #from SSD2/Clips
model = "/media/david/SSD2/PrimaryDataset/model_K1-5_CPU_epoch-6-0.9795-2023-12-16.jld2"
train("DFMN1-5", 20, images, model)
=#

# Global label-name => class-index table. Filled in by register_label_to_index!
# so the DataLoader worker tasks can see it (see linked discourse thread below).
const LABELTOINDEX::Dict{String,Int32} = Dict()

# `pretrain` argument: either a Bool (use/skip ImageNet weights) or a String
# path to a saved jld2 model state to resume from.
const Model = Union{Bool,String}

"""
    train(model_name, train_epochs, images,
          pretrain=true, train_test_split=0.8, batch_size=64)

Train (or resume training of) a ResNet18 image classifier.

Labels are taken from the parent directory name of each image path
(`.../<label>/<file>.png`). Assumes 224x224 pixel RGB png images.
Saves a jld2 model state per epoch (see `training_loop!`).

# Arguments
- `model_name`: tag used in the saved model / label file names.
- `train_epochs`: number of epochs to run.
- `images`: vector of png file paths; at least 2 distinct labels required.
- `pretrain`: `true`/`false` for ImageNet weights, or a path to a jld2
  model state to resume from (see the two `load_model` methods).
- `train_test_split`: fraction of batches used for training.
- `batch_size`: samples per batch; the dataset is truncated to a whole
  number of batches.
"""
function train(
    model_name::String,
    train_epochs::Int64,
    images::Vector{String}, #glob_pattern::String = "*/*.png"
    pretrain::Model = true,
    train_test_split::Float64 = 0.8,
    batch_size::Int64 = 64,
)
    epochs = 1:train_epochs
    #images = Glob.glob(glob_pattern) #|> shuffle! |> x -> x[1:640]
    # Validate inputs explicitly (not @assert, which can be compiled out).
    isempty(images) && throw(ArgumentError("No png images found"))
    @info "$(length(images)) images in dataset"
    label_to_index = labels_to_dict(images)
    register_label_to_index!(label_to_index)
    @info "Text labels translate to: " label_to_index
    classes = length(label_to_index)
    classes >= 2 || throw(
        ArgumentError("At least 2 label classes are required, for example: kiwi, not_kiwi"),
    )
    @info "$classes classes in dataset"
    @info "Device: $device"
    # Truncate to a whole number of batches, then split train/test on a
    # batch boundary.
    ceiling = seil(length(images), batch_size)
    train_test_index = train_test_idx(ceiling, batch_size, train_test_split)
    train, train_sample, test = process_data(images, train_test_index, ceiling, batch_size)
    @info "Made data loaders"
    model = load_model(pretrain, classes)
    @info "Loaded model"
    opt = Flux.setup(Flux.Optimisers.Adam(1e-5), model)
    @info "Setup optimiser"
    @info "Training for $epochs epochs: " now()
    training_loop!(
        model,
        opt,
        train,
        train_sample,
        test,
        epochs,
        model_name,
        classes,
        label_to_index,
    )
    @info "Finished $(last(epochs)) epochs: " now()
end

# Wraps image paths whose samples get train-time augmentation in getindex.
struct ImageContainer{T<:Vector}
    img::T
end

# Wraps image paths loaded without augmentation (evaluation only).
struct ValidationImageContainer{T<:Vector}
    img::T
end

Container = Union{ImageContainer,ValidationImageContainer}

# Round n down to a whole number of batches ("seil" is the established
# internal name, kept for compatibility).
function seil(n::Int, batch_size::Int)
    return n ÷ batch_size * batch_size
end

# Index (in samples) of the last training batch: the train/test split
# rounded to the nearest whole batch.
function train_test_idx(ceiling::Int, batch_size::Int, train_test_split::Float64)::Int
    return round(Int, ceiling ÷ batch_size * train_test_split) * batch_size
end

# Map each distinct parent-directory label, sorted alphabetically, to a
# 1-based class index.
function labels_to_dict(list::Vector{String})::Dict{String,Int32}
    labels = sort(unique(map(p -> split(p, "/")[end-1], list)))
    return Dict(zip(labels, 1:length(labels)))
end

"""
    register_label_to_index!(label_to_index::Dict{String,Int32})

This will replace the content of the global variable LABELTOINDEX with
the content intended by the caller.
Thanks algunion https://discourse.julialang.org/t/dataloader-scope-troubles/105207/4
"""
function register_label_to_index!(label_to_index::Dict{String,Int32})
    return merge!(empty!(LABELTOINDEX), label_to_index)
end

# Move models/data to the GPU when CUDA is usable, otherwise stay on CPU.
const device = CUDA.functional() ? gpu : cpu

# Shuffle (seeded for reproducibility) and split the file list into train,
# train-sample (a train-sized-like slice used to report train accuracy) and
# test loaders.
function process_data(array_of_file_names, train_test_index, ceiling, batch_size)
    seed!(1234)
    images = shuffle!(array_of_file_names)
    train = make_dataloader(ImageContainer(images[1:train_test_index]), batch_size)
    # Same size as the test slice, drawn from the training portion.
    train_sample =
        make_dataloader(ValidationImageContainer(images[1:(ceiling-train_test_index)]), batch_size)
    test =
        make_dataloader(ValidationImageContainer(images[train_test_index+1:ceiling]), batch_size)
    return train, train_sample, test
end

length(data::ImageContainer) = length(data.img)
length(data::ValidationImageContainer) = length(data.img)

# Training sample: load, augment (gaussian noise + random gray cutout
# masks), then convert to a Float32 array permuted to (W, H, C).
function getindex(data::ImageContainer{Vector{String}}, index::Int)
    path = data.img[index]
    img =
        Images.load(path) |>
        #x -> Images.imresize(x, 224, 224) |>
        #x -> Images.RGB.(x) |>
        x -> Noise.add_gauss(x, (rand() * 0.2)) |>
        x -> apply_mask!(x, 3, 3, 12) |>
        x -> collect(channelview(float32.(x))) |>
        x -> permutedims(x, (3, 2, 1))
    # Label comes from the parent directory name.
    y = LABELTOINDEX[(split(path, "/")[end-1])]
    return img, y
end

# Validation/test sample: same conversion, no augmentation.
function getindex(data::ValidationImageContainer{Vector{String}}, index::Int)
    path = data.img[index]
    img =
        Images.load(path) |>
        #x -> Images.imresize(x, 224, 224) |>
        #x -> Images.RGB.(x) |>
        x -> collect(channelview(float32.(x))) |>
        x -> permutedims(x, (3, 2, 1))
    y = LABELTOINDEX[(split(path, "/")[end-1])]
    return img, y
end

# Cutout augmentation: paint up to `max_number` random horizontal and
# vertical gray bands onto the image, in place.
# assumes 224px square images
function apply_mask!(
    img::Array{RGB{N0f8},2},
    max_number::Int = 3,
    min_size::Int = 3,
    max_size::Int = 22,
)
    gray = RGB{N0f8}(0.7, 0.7, 0.7)
    # horizontal bands
    for range in get_random_ranges(max_number, min_size, max_size)
        img[range, :] .= gray
    end
    # vertical bands
    for range in get_random_ranges(max_number, min_size, max_size)
        img[:, range] .= gray
    end
    return img
end

# Draw a random count (0:max_number) of random index ranges that fit
# inside 1:224. NOTE: `start:start+size` spans size+1 pixels.
# assumes 224px square images
function get_random_ranges(max_number::Int, min_size::Int, max_size::Int)
    number = rand(0:max_number)
    ranges = UnitRange{Int}[] # concrete eltype; [] would be Vector{Any}
    while length(ranges) < number
        start = rand(1:224)
        size = rand(min_size:max_size)
        if start + size > 224
            continue # would overflow the image edge; redraw
        end
        push!(ranges, start:start+size)
    end
    return ranges
end

# Wrap a container in a batched, collated, parallel DataLoader; stream
# batches through CuIterator when running on the GPU.
function make_dataloader(container::Container, batch_size::Int)
    data = Flux.DataLoader(container; batchsize = batch_size, collate = true, parallel = true)
    return device == gpu ? CuIterator(data) : data
end

# see load_model() from predict, and below
# Fresh ResNet18 backbone (optionally ImageNet-pretrained) with a new
# `classes`-way head.
function load_model(pretrain::Bool, classes::Int64)
    fst = Metalhead.ResNet(18, pretrain = pretrain).layers
    lst = Flux.Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => classes))
    model = Flux.Chain(fst[1], lst) |> device
    return model
end

#If model classes == desired classes I don't empty the last layer
#That means that I can just train from where I left off for new data, DFMN model
#Could be a gotcha if I want to train a different 4 class model, no need for a switch just yet
# Resume from a saved jld2 model state; replace the head only when the
# saved class count differs from `classes`.
function load_model(model_path::String, classes::Int64)
    model_state = JLD2.load(model_path, "model_state")
    # Number of output classes of the saved head, read from the saved
    # state tree — assumes the layout produced by Flux.state on this
    # exact architecture (TODO confirm if the architecture changes).
    model_classes = length(model_state[1][2][1][3][2])
    f = Metalhead.ResNet(18, pretrain = false).layers
    l = Flux.Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => model_classes))
    m = Flux.Chain(f[1], l)
    Flux.loadmodel!(m, model_state)
    if classes == model_classes
        model = m |> device
    else
        # Keep the trained backbone, attach a fresh head for `classes`.
        fst = m.layers
        lst = Flux.Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => classes))
        model = Flux.Chain(fst[1], lst) |> device
    end
    return model
end

# Run model m over data loader d and return (accuracy rounded to 4dp,
# c-class confusion matrix). Predictions via onecold against 1:c.
function evaluate(m, d, c)
    good = 0
    count = 0
    pred = Int64[]
    actual = Int64[]
    for (x, y) in d
        p = Flux.onecold(m(x))
        good += sum(p .== y)
        count += length(y)
        append!(pred, p)
        append!(actual, y)
    end
    accuracy = round(good / count, digits = 4)
    confusion_matrix = MLBase.confusmat(c, actual, pred)
    #freqtable(DataFrames.DataFrame(targets = actual, predicts = pred), :targets, :predicts)
    #roc=MLBase.roc(actual, pred, 100)
    #f1=MLBase.f1score(roc)
    return accuracy, confusion_matrix #, roc, f1
end

# One pass over the training loader with logit cross-entropy loss.
function train_epoch!(model; opt, train, classes)
    Flux.train!(model, train, opt) do m, x, y
        Flux.Losses.logitcrossentropy(m(x), Flux.onehotbatch(y, 1:classes))
    end
end

# Write "label => index" lines to labels_<model_name>-<today>.txt so a
# saved model can be decoded later.
function dict_to_text_file(dict, model_name)
    open("labels_$(model_name)-$(today()).txt", "w") do file
        for (key, value) in dict
            write(file, "$(key) => $(value)\n")
        end
    end
    @info "Saved labels to file for future reference"
end

# Epoch loop: warm-up evaluation, then per epoch train + evaluate on the
# train sample and test slices, saving a CPU copy of the model each epoch.
function training_loop!(
    model,
    opt,
    train,
    train_sample,
    test,
    epochs::UnitRange{Int64},
    model_name::String,
    classes,
    label_to_index,
)
    # Warm-up pass compiles the evaluation path before timed epochs.
    # (Local renamed from `eval`, which shadowed Base.eval.)
    @time warm_accuracy, warm_cm = evaluate(model, test, classes)
    @info "warm up accuracy" accuracy = warm_accuracy
    @info "warm up confusion matrix" warm_cm
    a = 0 # kept for the disabled best-model gate below
    for epoch in epochs
        println("")
        @info "Starting Epoch: $epoch"
        # Persist the label mapping once, alongside the first epoch.
        epoch == 1 && dict_to_text_file(label_to_index, model_name)
        @time train_epoch!(model; opt, train, classes)
        @time train_accuracy, train_confusion_matrix = evaluate(model, train_sample, classes)
        @info "Epoch: $epoch"
        @info "train" accuracy = train_accuracy
        @info "train" train_confusion_matrix
        @time test_accuracy, test_confusion_matrix = evaluate(model, test, classes)
        @info "test" accuracy = test_accuracy
        @info "test" test_confusion_matrix
        # number kiwi guessed right, assumes kiwi=1, not=2 (alphabetical)
        #test_confusion_matrix[1,1] > a && begin
        #a = test_confusion_matrix[1,1]
        let _model = cpu(model)
            jldsave(
                "/media/david/SSD2/model_$(model_name)_CPU_epoch-$epoch-$test_accuracy-$(today()).jld2";
                model_state = Flux.state(_model),
            )
            @info "Saved a best_model"
        end
        #end
    end
end