functions.jl (1131B)
#!/usr/bin/julia

using LinearAlgebra
using LaTeXStrings
using Distributions
using Plots

# Presumably provides `cpd_eval`, which is used below but not defined here — TODO confirm.
include("../../sesh2/src/functions.jl")

"""
    cp_als(U, V, iter)

Fit CP factor matrices to the tensor represented by the factor matrices `U`, starting
from the factor matrices `V` and running `iter` alternating-least-squares (ALS)
iterations.

Each iteration sweeps the modes in the order `2, …, d, d-1, …, 1`, so modes
`2, …, d-1` are updated twice per iteration. Updating mode `k` solves the normal
equations `Vk[k] * G = U[k] * F'`, where `G` is the elementwise product of
`Vk[l]' * Vk[l]` and `F` the elementwise product of `Vk[l]' * U[l]` over all `l != k`.

# Arguments
- `U`: `d` factor matrices; `U[l]` must have as many rows as `V[l]`.
- `V`: `d` initial factor matrices with `r = size(V[1], 2)` columns. `V` itself is
  not modified; a shallow copy is updated.
- `iter`: number of ALS iterations (`0` returns a copy of `V` and empty histories).

# Returns
`(Vk, phi, d_phi_2)`:
- `Vk`: the updated factor matrices.
- `phi[i]`: `norm(cpd_eval(Vk) - cpd_eval(U))` after iteration `i`.
- `d_phi_2[i]`: `norm(diff(0.5 .* phi[1:i] .^ 2))`, i.e. the 2-norm of all successive
  changes of `phi²/2` up to iteration `i` (`0.0` for `i == 1`).

Assumes `cpd_eval` assembles the full tensor from a vector of factor matrices — TODO
confirm against its definition in the included file.
"""
function cp_als(U, V, iter)
    d = length(U)
    Vk = copy(V)  # shallow copy: entries are replaced below, V's matrices stay untouched

    # The target tensor never changes, so assemble it once rather than every iteration.
    T = cpd_eval(U)

    phi = Float64[]
    d_phi_2 = Float64[]
    sweep = [2:d; d-1:-1:1]  # 2, …, d, d-1, …, 1 (empty when d == 1)
    for _ in 1:iter
        for k in sweep
            others = [l for l in 1:d if l != k]
            F = eleprod([Vk[l]' * U[l] for l in others])   # r × R
            G = eleprod([Vk[l]' * Vk[l] for l in others])  # r × r, symmetric
            # Because G is symmetric, this is the same linear system as
            # kron(G, I) * vec(Vk[k]) = kron(F, I) * vec(U[k]), but it needs only an
            # r × r solve instead of an (n_k r) × (n_k r) one.
            Vk[k] = (U[k] * F') / G
        end
        push!(phi, norm(cpd_eval(Vk) - T))
        # NOTE(review): this is the norm over the *whole* history of changes in phi²/2,
        # not just the latest step, so the sequence never decreases — confirm intended.
        push!(d_phi_2, norm(diff(0.5 .* phi .^ 2)))
    end
    return Vk, phi, d_phi_2
end


"""
    eleprod(A)

Return the elementwise (Hadamard) product `A[1] .* A[2] .* … .* A[end]` of the arrays
in `A`.

`A[1]` is copied, so the result never aliases an input, even when `A` has a single
element. Broadcasting rules apply, so the entries only need compatible shapes.
Throws a `BoundsError` if `A` is empty.
"""
function eleprod(A)
    Z = copy(A[1])
    for B in A[2:end]
        Z = Z .* B  # out-of-place, so broadcasting may change the result's shape
    end
    return Z
end