tensor_methods

Tensor Methods for DS and SC
git clone git://popovic.xyz/tensor_methods.git
Log | Files | Refs

unf-aprx.jl (1569B)


      1 #!/usr/bin/julia
      2 
      3 using LinearAlgebra
      4 using Plots
      5 using Distributions
      6 using LaTeXStrings
      7 using Random
      8 using TSVD: tsvd # todo: implement
      9 using TensorToolbox: tenmat, ttm, contract # todo:implement
     10 using Plots.Measures
     11 
     12 function σ_unfold(C, d)
     13     Σ_s = []
     14     for k=1:d
     15         U, Σ, V = svd(tenmat(C, k))
     16         append!(Σ_s, [Σ])
     17     end
     18     return Σ_s
     19 end
     20 
     21 function ttmps_unfold(A, k)
     22     n = size(A)
     23     d = length(n)
     24     C = copy(A)
     25     A_k = reshape(C, (prod([n[i] for i in 1:k]), prod([n[i] for i in k+1:d])))
     26     return A_k
     27 end
     28 
     29 function tucker_unfold(A, k)
     30     n = size(A)
     31     d = length(n)
     32     C = copy(A)
     33     C_perm = permutedims(C, ([i for i=1:d if i!=k]..., k))
     34     A_k = reshape(C_perm, (prod([n[i] for i=1:d if i!=k]), n[k]))
     35     return A_k
     36 end
     37 
     38 ϵ_s = [1/(10^(j*2)) for j=1:5] # computer not goode enough for j = 6
     39 function rank_approx(C, method,ϵ_s=ϵ_s)
     40     d = length(size(C))
     41     ϵ_jk = []; r_jk = []; σ_jk = []
     42     for k=1:d
     43         C_k = method(C, k)
     44         for (j, ϵ_j) in enumerate(ϵ_s)
     45             for r=1:rank(C_k)
     46                 U_hat, Σ_hat, V_hat = tsvd(C_k, r)
     47                 C_k_hat = U_hat * Diagonal(Σ_hat) * transpose(V_hat)
     48                 if norm(C_k_hat-C_k)/norm(C_k) <= ϵ_j
     49                     append!(ϵ_jk, norm(C_k_hat-C_k)/norm(C_k))
     50                     append!(r_jk, r)
     51                     append!(σ_jk, [Σ_hat])
     52                     break
     53                 end
     54             end
     55         end
     56     end
     57     ndims = (length(ϵ_s), d)
     58     return reshape(r_jk, ndims), reshape(ϵ_jk, ndims), reshape(σ_jk, ndims)
     59 end