main.jl (1105B)
using LinearAlgebra
using Random
using LaTeXStrings
using Distributions
using Plots

include("../../sesh2/src/functions.jl")
include("../../sesh3/src/functions.jl")

"""
    errplot(x, phi, d_phi_2, n; outdir="./plots")

Plot `phi` and `d_phi_2` against `x` on a single figure titled with `n`, and save it
as a PNG named `err_<n>.png` inside `outdir`. The directory is created if it is missing.

# Arguments
- `x`: abscissa shared by both series (`main` passes the iteration indices of `phi`).
- `phi`: first series, shown in the legend as φ.
- `d_phi_2`: second series, shown in the legend as ∇½φ².
- `n`: used only for the plot title and the output file name.

# Keywords
- `outdir`: directory the PNG is written to; defaults to the original `./plots`.

# Returns
The path of the written PNG file.
"""
function errplot(x, phi, d_phi_2, n; outdir="./plots")
    p = plot(x, phi;
             lw=3,
             label=L"\phi",   # without a label the legend would show Plots' default "y1"
             titlefontsize=20,
             xlabelfontsize=14,
             ylabelfontsize=14,
             dpi=300,
             grid=false,
             size=(500, 400))
    plot!(p, x, d_phi_2;
          lw=3,
          title="n=$n",
          label=L"\nabla \frac{1}{2} \phi^2")

    # savefig does not create missing directories, so make sure the target exists.
    mkpath(outdir)
    path = joinpath(outdir, "err_$n.png")
    savefig(p, path)
    return path
end

"""
    _random_unit_columns(rng, m, r)

Return an `m × r` `Matrix{Float64}` whose columns are independent draws of `m` entries
from {-1, 0, 1}, each scaled to unit 2-norm. An all-zero draw is redrawn, because
`normalize` of a zero vector yields NaNs that would otherwise be passed on to `cp_als`.
Columns are drawn in order `1:r`, matching the draw order of the original inline code.
"""
function _random_unit_columns(rng, m, r)
    m > 0 || throw(ArgumentError("column length must be positive, got $m"))
    A = Matrix{Float64}(undef, m, r)
    for j in axes(A, 2)
        v = rand(rng, -1:1, m)
        while all(iszero, v)   # probability 3^-m per column; redraw instead of NaNs
            v = rand(rng, -1:1, m)
        end
        A[:, j] = normalize(v)
    end
    return A
end

"""
    main(; nr=[(2, 7), (3, 23), (4, 49)], maxiter=10000, rng=Random.default_rng())

Run the CP-ALS experiment for every `(n, r)` pair in `nr`:

1. Build three random `n^2 × r` starting factors via `_random_unit_columns`.
2. Run `cp_als` against the second value returned by `multiplication_tensor(n)`.
3. Plot the two histories it returns with `errplot`.

The defaults reproduce the original hard-coded experiment. Pass e.g.
`rng=MersenneTwister(0)` for a reproducible run.

`cp_als` and `multiplication_tensor` come from the included `functions.jl` files. Their
exact contracts are not visible here (NOTE(review): confirm them there).
"""
function main(; nr=[(2, 7), (3, 23), (4, 49)], maxiter=10000, rng=Random.default_rng())
    for (n, r) in nr
        # Initial guess: three factor matrices, each n^2 × r.
        V = [_random_unit_columns(rng, n^2, r) for _ in 1:3]
        # The original comment calls the second output the "given n^3 CPD"; only it is used.
        _, U = multiplication_tensor(n)
        # Third argument presumably caps the number of ALS iterations — TODO confirm.
        _, phi, d_phi_2 = cp_als(U, V, maxiter)
        errplot(eachindex(phi), phi, d_phi_2, n)
    end
    return nothing
end

main()