plot_functions.py 3.64 KB
Newer Older
Jonathan Juhl's avatar
all  
Jonathan Juhl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

def plot_function(self,Models,y,logwriter,logfile,strategy,average_dis,distribution,l_list,ite,validate_interval,prefix):
    """Encode a batch, plot the embedding, and print/log training statistics.

    Args:
        Models: Batch fed to ``self.predict_encoder`` after casting to
            ``self.precision``.
            NOTE(review): the original body cast an undefined name ``image``
            (always an UnboundLocalError); ``Models`` is the only otherwise
            unused input, so it is assumed to be that batch -- confirm
            against the caller.
        y: Reference labels. They are used as plot colours and, when
            ``self.verbose`` is truthy, as ground truth for accuracy and NMI.
        logwriter: Object with ``writerow(dict)`` (presumably a
            ``csv.DictWriter``).
        logfile: File-like object flushed after each logged row.
        strategy: Object whose ``run(fn, args=...)`` executes the encoder
            (presumably a ``tf.distribute.Strategy``).
        average_dis: Sequence of average-distance values. It is plotted, and
            its last element is printed and logged.
        distribution: Forwarded unchanged to ``self.gui_plot``.
        l_list: Sequence of loss values. All but the first are plotted; the
            last is printed and logged.
        ite: Current iteration number.
        validate_interval: Iterations between validations. ``ite`` divided by
            it gives the plot index.
        prefix: Filename prefix forwarded to ``self.gui_plot``.
    """
    image = tf.cast(Models, self.precision)

    features, feat = strategy.run(self.predict_encoder, args=(image,))

    # ``.numpy().astype`` returns new arrays; keep them (the original discarded them).
    feat = feat.numpy().astype(np.float32)
    features = features.numpy().astype(np.float32)
    # Every distinct feature row is one cluster; the inverse maps each sample
    # to the index of its unique row, i.e. it is already the cluster label.
    # Flatten it so the labels are 1-D regardless of the NumPy version.
    _, invers_counts = np.unique(features, axis=0, return_inverse=True)
    cluster_ids = np.asarray(invers_counts).reshape(-1)

    # Plot the data with label colours.
    self.gui_plot(int(ite/validate_interval), feat, np.asarray(l_list)[1:], np.asarray(average_dis), distribution, colors=y, prefix=prefix)

    if self.verbose:
        acc = np.round(accuracy_score(y, cluster_ids), 5)  # accuracy
        nmi = np.round(normalized_mutual_info_score(y, cluster_ids), 5)  # normalized mutual information
        print('Iter %d: of %d acc=%.5f, nmi=%.5f, loss=%.5f, dis_avg=%.5f' % (ite, self.steps, acc, nmi, l_list[-1], average_dis[-1]))
        logdict = dict(iter=ite, acc=acc, nmi=nmi, loss=l_list[-1], dis_avg=average_dis[-1])
    else:
        print('Iter %d: of %d ,loss=%.5f, dis_avg=%.5f' % (ite, self.steps, l_list[-1], average_dis[-1]))
        logdict = dict(iter=ite, loss=l_list[-1], dis_avg=average_dis[-1])

    logwriter.writerow(logdict)
    # Flush in both branches so the log is persisted promptly
    # (the original only flushed in the verbose branch).
    logfile.flush()


def plot_gan(self,Models,logwriter,logfile,strategy,d_list,l_list,ite,validate_interval,prefix,generator):
    """Run the generator, plot its output with the loss curves, and log the loss.

    Args:
        Models: Input forwarded to ``generator`` through ``strategy.run``.
        logwriter: Object with ``writerow(dict)`` (presumably a
            ``csv.DictWriter``).
        logfile: File-like object flushed after the row is written.
        strategy: Object whose ``run(fn, args=...)`` executes ``generator``.
        d_list: Sequence of losses, presumably the discriminator's. All but
            the first are plotted; the last is printed.
        l_list: Sequence of losses, presumably the generator's. All but the
            first are plotted; the last is printed and logged as ``loss``.
        ite: Current iteration number.
        validate_interval: Iterations between validations. ``ite`` divided by
            it gives the plot index.
        prefix: Filename prefix forwarded to ``self.gui_plot``.
        generator: Callable executed via ``strategy.run`` on ``Models``.
    """
    # The original began by casting an undefined ``image`` (always a
    # NameError) whose result was never used; that line is removed.
    models = strategy.run(generator, args=(Models,))
    view_models = self.project_a_model(models)
    # NOTE(review): this argument list (5 positionals + ``prefix``) does not
    # match the module-level ``gui_plot(self, image_real, loss_generator,
    # loss_discriminator, model_similarity, prefix=None)`` -- confirm which
    # ``gui_plot`` is bound to ``self``.
    self.gui_plot(int(ite/validate_interval), view_models, np.asarray(d_list)[1:], np.asarray(l_list)[1:], models, prefix=prefix)
    # The format string has four placeholders; the original supplied only two
    # arguments, which raised TypeError on every call.
    print('Iter %d: of %d loss_generator=%.5f, loss_descriminator=%.5f' % (ite, self.steps, l_list[-1], d_list[-1]))
    logdict = dict(iter=ite, loss=l_list[-1])
    logwriter.writerow(logdict)
    logfile.flush()  # persist the row promptly
        

def gui_plot(self,image_real,loss_generator,loss_discriminator,model_similarity,prefix=None):
    """Save a 2x2 diagnostic figure to ``join(self.results, '<prefix>_output.png')``.

    Panels:
        - Top-left: ``image_real`` reshaped to one ``(2s, 2s)`` image.
        - Top-right: the generator loss curve.
        - Bottom-left: the discriminator loss curve.
        - Bottom-right: ``model_similarity @ model_similarity.T``.

    Args:
        image_real: Array with ``s = image_real.shape[1]``. It must contain
            exactly ``4 * s * s`` elements so it can be reshaped to
            ``(2s, 2s)``.
            NOTE(review): a plain reshape does not tile, for example, four
            ``(s, s)`` images into a 2x2 mosaic (rows get interleaved) --
            confirm the intended layout.
        loss_generator: Loss values indexed along axis 0. The x axis is the
            index times ``self.validate_interval``.
        loss_discriminator: Loss values with the same x scaling.
        model_similarity: 2-D array whose Gram matrix is displayed.
        prefix: Prefix of the output filename.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
    s = image_real.shape[1]
    top_r = np.reshape(image_real, (2 * s, 2 * s))

    # The original also reshaped an undefined ``image_fake`` and ran TSNE on an
    # undefined ``features`` whose result was never used (both NameErrors);
    # those lines are removed.

    ax1.imshow(top_r)
    ax1.set_title("Projections from Orthogonal Vectors")

    ax2.plot(np.arange(loss_generator.shape[0]) * self.validate_interval, loss_generator, 'go--', linewidth=1, markersize=8)
    ax2.set_ylabel('loss', fontsize=10)
    ax2.set_title("Generative Loss")

    # Keep the discriminator loss curve visible; the original drew an image
    # over this axis, hiding the plot its title describes.
    ax3.plot(np.arange(loss_discriminator.shape[0]) * self.validate_interval, loss_discriminator, 'ro--', linewidth=1, markersize=8)
    ax3.set_ylabel('loss', fontsize=10)
    ax3.set_title("Discriminative  Loss")

    ax4.imshow(np.dot(model_similarity, model_similarity.T))
    ax4.set_ylabel('Cross correlation', fontsize=10)
    ax4.set_title("Model similarity")

    # Lay out after all titles and labels exist so they are accounted for.
    fig.tight_layout()
    fig.savefig(join(self.results, '%s_output.png' % prefix))

    plt.clf()
    plt.close('all')