% Multiclass graph-based semi-supervised segmentation of the three-moons data.
% Solves
%     min_u  <u, f> + ||D u||_1   subject to each row of u lying in the simplex
% with two solvers (PDHG and ADMM), logging the primal-dual gap and the
% classification accuracy. make3moons, create_graph_zmp, adj_to_gradient,
% projSimplex and prox_l1 are not defined in this script.
clc; close; clear;

%% Generate data
[X, y, n] = make3moons(2500, sqrt(0.02), 100);
y = y(:);            % force a column so comparisons with results never broadcast
k = max(y);
classes = y;

NN = 8;              % number of nearest neighbours to consider
M = 17;              % local scale parameter
A = create_graph_zmp(X, NN, M, 'avg');
num_vertices = size(A, 1);
% Halved because each undirected edge presumably appears twice in A (A symmetric).
num_edges = sum(sum(double(A > 0))) / 2;
D = adj_to_gradient(A, num_vertices, num_edges);
assert(n == num_vertices, 'Graph size does not match the number of samples.');

%% Distribute labels: the first num_labeled points of every class are labeled
percent_labeled = 5;
weights = 10;
num_labeled = max(1, round((percent_labeled / (100 * k)) * num_vertices));  % per class
idx_label = zeros(num_labeled * k, 1);
for l = 1:k
    indices = find(classes == l);
    if numel(indices) < num_labeled
        error('Class %d has only %d points but %d labels were requested.', ...
              l, numel(indices), num_labeled);
    end
    idx_label((l - 1) * num_labeled + (1:num_labeled)) = indices(1:num_labeled);
end
num_labeled = num_labeled * k;

% Data term: a labeled vertex costs -weights in its own class column, 0 elsewhere.
f = zeros(num_vertices, k);
for i = 1:num_labeled
    f(idx_label(i), classes(idx_label(i))) = -1;
end
f = f * weights;

figure(1)
hold on;
title('labeled data');
plot(X(:, 1), X(:, 2), 'o', 'color', [211, 211, 211] / 255);
plot(X(idx_label, 1), X(idx_label, 2), 'or');

%% Initialize general parameters
gap_tol = 1e-3;      % both solvers stop once the primal-dual gap falls below this
max_iters = 1500;

%% PDHG
u = zeros(n, k);
p = zeros(num_edges, k);
normD = normest(D, 1e-6);
% Primal/dual step sizes 1/s and 1/t satisfy (1/s)*(1/t)*||D||^2 < 1.
s = 1.01 * normD;
t = 1.01 * normD;
primal_energy_pdhg = zeros(1, max_iters);
dual_energy_pdhg = zeros(1, max_iters);
for it = 1:max_iters
    % Primal step on all classes at once, then projection onto the simplex
    % (presumably row-wise inside projSimplex).
    u_prev = u;
    u = projSimplex(u - (D' * p + f) / s);
    % Dual step at the over-relaxed point 2u - u_prev; clipping keeps |p| <= 1.
    p = min(max(p + D * (2 * u - u_prev) / t, -1), 1);

    primal_energy_pdhg(it) = sum(u(:) .* f(:)) + sum(reshape(abs(D * u), [], 1));
    % The dual objective is sum_i min_j (f + D'p)_ij; it is stored with flipped
    % sign, so primal + stored dual is the duality gap.
    dual_energy_pdhg(it) = sum(max(-D' * p - f, [], 2));
    gap = primal_energy_pdhg(it) + dual_energy_pdhg(it);
    fprintf('%d %.6f\n', it, gap);
    if gap < gap_tol
        break;
    end
end
primal_energy_pdhg = primal_energy_pdhg(1:it);
dual_energy_pdhg = dual_energy_pdhg(1:it);

%% ADMM (splitting v1 = u for the simplex constraint, v2 = D*u for the l1 term)
u = zeros(n, k);
v1 = zeros(n, k);
v2 = zeros(num_edges, k);
p1 = zeros(n, k);
p2 = zeros(num_edges, k);
rho = 1;
% The u-subproblem matrix I + D'D never changes: factor it once, with a
% fill-reducing permutation Q such that Q'*(I + D'D)*Q = R'*R.
[R, chol_flag, Q] = chol(sparse(speye(size(D, 2)) + D' * D));
if chol_flag ~= 0
    error('I + D''*D is not positive definite.');
end
primal_energy_admm = zeros(1, max_iters);
dual_energy_admm = zeros(1, max_iters);
for it = 1:max_iters
    % u-update: (I + D'D) u = v1 + D'v2 - (p1 + D'p2)/rho, for all classes together.
    rhs = v1 + D' * v2 - (p1 + D' * p2) / rho;
    u = Q * (R \ (R' \ (Q' * rhs)));

    % v1-update: project u + (p1 - f)/rho onto the simplex.
    v1 = projSimplex(u + p1 / rho - f / rho);
    % v2-update: proximal map of the l1 norm with parameter 1/rho, per class.
    Du = D * u;
    for i = 1:k
        v2(:, i) = prox_l1(Du(:, i) + p2(:, i) / rho, 1 / rho);
    end

    % Lagrange multiplier updates.
    p1 = p1 + rho * (u - v1);
    p2 = p2 + rho * (Du - v2);

    % Energies use the feasible iterate v1. The dual value is stored with flipped
    % sign, so primal + stored dual is the duality gap.
    primal_energy_admm(it) = sum(v1(:) .* f(:)) + sum(reshape(abs(D * v1), [], 1));
    dual_energy_admm(it) = sum(max(-D' * p2 - f, [], 2));
    gap = primal_energy_admm(it) + dual_energy_admm(it);
    fprintf('%d %.6f\n', it, gap);
    [~, results] = max(u, [], 2);
    fprintf('accuracy:%.4f\n', nnz(results == y) / n);
    if gap < gap_tol
        break;
    end
end
primal_energy_admm = primal_energy_admm(1:it);
dual_energy_admm = dual_energy_admm(1:it);

%% Plot the ADMM segmentation
[~, results] = max(u, [], 2);
figure(2)
hold on;
markers = {'or', 'ob', 'og'};   % cycled when there are more than three classes
for i = 1:k
    index = find(results == i);
    plot(X(index, 1), X(index, 2), markers{mod(i - 1, numel(markers)) + 1});
end
fprintf('accuracy:%.4f\n', nnz(results == y) / n);