modify
This commit is contained in:
		
							parent
							
								
									2bd4f09891
								
							
						
					
					
						commit
						86b50fb075
					
				
							
								
								
									
										2
									
								
								model.py
								
								
								
								
							
							
						
						
									
										2
									
								
								model.py
								
								
								
								
							| 
						 | 
				
			
			@ -245,7 +245,7 @@ class FastRCNN(nn.Module):
 | 
			
		|||
            # filter by threshold
 | 
			
		||||
            cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
 | 
			
		||||
            print(cls_prob.shape)
 | 
			
		||||
            pos_mask = cls_prob[:, 1] > thresh
 | 
			
		||||
            pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes)
 | 
			
		||||
            print(pos_mask.shape)
 | 
			
		||||
            proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
 | 
			
		||||
            print(proposals_img.shape)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue